Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add healthcheck support for JetStream #90

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,3 +885,17 @@ async def Decode( # pylint: disable=invalid-overridden-method
)
# Reset buffer after flushed.
buffered_response_list = []

async def HealthCheck(
self,
request: jetstream_pb2.HealthCheckRequest,
context: Optional[grpc.aio.ServicerContext] = None,
) -> jetstream_pb2.HealthCheckResponse:
"""Reports whether the driver behind this orchestrator is live.

Args:
request: The health-check request. It is not read by this method.
context: The gRPC servicer context. When it is None a warning is logged
that the orchestrator is running in offline test mode; the method
still goes on to build and return a response.

Returns:
A HealthCheckResponse whose `is_live` field is the current value of
`self._driver.live`.
"""
if context is None:
logging.warning(
"LLM orchestrator is being used in offline test mode, and will not"
" respond to gRPC queries - only direct function calls."
)
# Liveness is whatever the driver currently reports; nothing is probed here.
is_live = self._driver.live
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you share where driver set live status?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for sharing it!

return jetstream_pb2.HealthCheckResponse(is_live=is_live)
9 changes: 9 additions & 0 deletions jetstream/core/proto/jetstream.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ package jetstream_proto;
// Orchestrator declares two RPCs: a server-streaming Decode and a unary
// HealthCheck.
service Orchestrator {
// Query LLM to generate text or tokens. The reply is a stream of
// DecodeResponse messages for a single DecodeRequest.
rpc Decode(DecodeRequest) returns (stream DecodeResponse) {}
// Checks if the model server is live. Takes a HealthCheckRequest and returns
// a single HealthCheckResponse.
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse) {}
}

message DecodeRequest {
Expand Down Expand Up @@ -74,4 +76,11 @@ message DecodeResponse {
}
reserved 1;
// Next ID: 4
}

// Request message for the Orchestrator.HealthCheck RPC. It declares no fields.
message HealthCheckRequest {}

// Response message for the Orchestrator.HealthCheck RPC.
message HealthCheckResponse {
// Denotes whether the model server is live.
bool is_live = 1;
}
46 changes: 24 additions & 22 deletions jetstream/core/proto/jetstream_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,33 @@
_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\x32]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
)


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto\"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03\"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\"\x14\n\x12HealthCheckRequest\"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse\"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse\"\x00\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "jetstream.core.proto.jetstream_pb2", _globals
)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'jetstream.core.proto.jetstream_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_globals["_DECODEREQUEST"]._serialized_start = 58
_globals["_DECODEREQUEST"]._serialized_end = 353
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 274
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 301
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 303
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 336
_globals["_DECODERESPONSE"]._serialized_start = 356
_globals["_DECODERESPONSE"]._serialized_end = 687
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 522
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 538
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 541
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670
_globals["_ORCHESTRATOR"]._serialized_start = 689
_globals["_ORCHESTRATOR"]._serialized_end = 782
_globals['_DECODEREQUEST']._serialized_start=58
_globals['_DECODEREQUEST']._serialized_end=353
_globals['_DECODEREQUEST_TEXTCONTENT']._serialized_start=274
_globals['_DECODEREQUEST_TEXTCONTENT']._serialized_end=301
_globals['_DECODEREQUEST_TOKENCONTENT']._serialized_start=303
_globals['_DECODEREQUEST_TOKENCONTENT']._serialized_end=336
_globals['_DECODERESPONSE']._serialized_start=356
_globals['_DECODERESPONSE']._serialized_end=687
_globals['_DECODERESPONSE_INITIALCONTENT']._serialized_start=522
_globals['_DECODERESPONSE_INITIALCONTENT']._serialized_end=538
_globals['_DECODERESPONSE_STREAMCONTENT']._serialized_start=541
_globals['_DECODERESPONSE_STREAMCONTENT']._serialized_end=670
_globals['_DECODERESPONSE_STREAMCONTENT_SAMPLE']._serialized_start=629
_globals['_DECODERESPONSE_STREAMCONTENT_SAMPLE']._serialized_end=670
_globals['_HEALTHCHECKREQUEST']._serialized_start=689
_globals['_HEALTHCHECKREQUEST']._serialized_end=709
_globals['_HEALTHCHECKRESPONSE']._serialized_start=711
_globals['_HEALTHCHECKRESPONSE']._serialized_end=749
_globals['_ORCHESTRATOR']._serialized_start=752
_globals['_ORCHESTRATOR']._serialized_end=937
# @@protoc_insertion_point(module_scope)
150 changes: 89 additions & 61 deletions jetstream/core/proto/jetstream_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,74 +21,102 @@


class OrchestratorStub(object):
"""TODO: Merge this with main JetStream core once we settle on an API."""
"""TODO: Merge this with main JetStream core once we settle on an API.

def __init__(self, channel):
"""Constructor.

Args:
channel: A grpc.Channel.
"""
self.Decode = channel.unary_stream(
"/jetstream_proto.Orchestrator/Decode",
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
)

def __init__(self, channel):
"""Constructor.

Args:
channel: A grpc.Channel.
"""
self.Decode = channel.unary_stream(
'/jetstream_proto.Orchestrator/Decode',
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
)
self.HealthCheck = channel.unary_unary(
'/jetstream_proto.Orchestrator/HealthCheck',
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
)


class OrchestratorServicer(object):
"""TODO: Merge this with main JetStream core once we settle on an API."""
"""TODO: Merge this with main JetStream core once we settle on an API.

"""

def Decode(self, request, context):
"""Query LLM to generate text or tokens.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def Decode(self, request, context):
"""Query LLM to generate text or tokens."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
def HealthCheck(self, request, context):
"""Checks if the model server is live.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_OrchestratorServicer_to_server(servicer, server):
rpc_method_handlers = {
"Decode": grpc.unary_stream_rpc_method_handler(
servicer.Decode,
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString,
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
"jetstream_proto.Orchestrator", rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
rpc_method_handlers = {
'Decode': grpc.unary_stream_rpc_method_handler(
servicer.Decode,
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString,
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString,
),
'HealthCheck': grpc.unary_unary_rpc_method_handler(
servicer.HealthCheck,
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString,
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'jetstream_proto.Orchestrator', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class Orchestrator(object):
"""TODO: Merge this with main JetStream core once we settle on an API."""

@staticmethod
def Decode(
request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.unary_stream(
request,
target,
"/jetstream_proto.Orchestrator/Decode",
jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
"""TODO: Merge this with main JetStream core once we settle on an API.

"""

@staticmethod
def Decode(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(request, target, '/jetstream_proto.Orchestrator/Decode',
jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def HealthCheck(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/jetstream_proto.Orchestrator/HealthCheck',
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
6 changes: 6 additions & 0 deletions jetstream/tests/core/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ async def test_server(
) as channel:
stub = jetstream_pb2_grpc.OrchestratorStub(channel)

healthcheck_request = jetstream_pb2.HealthCheckRequest()
healthcheck_response = stub.HealthCheck(healthcheck_request)
healthcheck_response = await healthcheck_response

assert healthcheck_response.is_live == True

# The string representation of np.array([[65, 66]]), [2] will be prepended
# as BOS
text = "AB"
Expand Down
Loading