Skip to content

Commit

Permalink
Add metadata field to agent messages (#1013)
Browse files Browse the repository at this point in the history
* add `metadata` field to agent messages

Currently only set when the default chat handler is used.

* include message metadata in TestLLMWithStreaming

* tweak TestLLMWithStreaming parameters

* pre-commit

* default to empty dict if no generation_info
  • Loading branch information
dlqqq authored Sep 25, 2024
1 parent 6e426ab commit 0884211
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 12 deletions.
9 changes: 5 additions & 4 deletions packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
time.sleep(5)
time.sleep(1)
yield GenerationChunk(
text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n"
text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n",
generation_info={"test_metadata_field": "foobar"},
)
for i in range(1, 101):
time.sleep(0.5)
for i in range(1, 6):
time.sleep(0.2)
yield GenerationChunk(text=f"{i}, ")
6 changes: 6 additions & 0 deletions packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
Provides classes which extend `langchain_core.callbacks:BaseCallbackHandler`.
Not to be confused with Jupyter AI chat handlers.
"""

from .metadata import MetadataCallbackHandler
26 changes: 26 additions & 0 deletions packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class MetadataCallbackHandler(BaseCallbackHandler):
"""
When passed as a callback handler, this stores the LLMResult's
`generation_info` dictionary in the `self.jai_metadata` instance attribute
after the provider fully processes an input.
If used in a streaming chat handler: the `metadata` field of the final
`AgentStreamChunkMessage` should be set to `self.jai_metadata`.
If used in a non-streaming chat handler: the `metadata` field of the
returned `AgentChatMessage` should be set to `self.jai_metadata`.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.jai_metadata = {}

def on_llm_end(self, response: LLMResult, **kwargs) -> None:
if not (len(response.generations) and len(response.generations[0])):
return

self.jai_metadata = response.generations[0][0].generation_info or {}
25 changes: 19 additions & 6 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import time
from typing import Dict, Type
from typing import Any, Dict, Type
from uuid import uuid4

from jupyter_ai.callback_handlers import MetadataCallbackHandler
from jupyter_ai.models import (
AgentStreamChunkMessage,
AgentStreamMessage,
Expand Down Expand Up @@ -85,13 +86,19 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:

return stream_id

def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
def _send_stream_chunk(
self,
stream_id: str,
content: str,
complete: bool = False,
metadata: Dict[str, Any] = {},
):
"""
Sends an `agent-stream-chunk` message containing content that should be
appended to an existing `agent-stream` message with ID `stream_id`.
"""
stream_chunk_msg = AgentStreamChunkMessage(
id=stream_id, content=content, stream_complete=complete
id=stream_id, content=content, stream_complete=complete, metadata=metadata
)

for handler in self._root_chat_handlers.values():
Expand All @@ -104,6 +111,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = Fals
async def process_message(self, message: HumanChatMessage):
self.get_llm_chain()
received_first_chunk = False
assert self.llm_chain

inputs = {"input": message.body}
if "context" in self.prompt_template.input_variables:
Expand All @@ -121,10 +129,13 @@ async def process_message(self, message: HumanChatMessage):
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
assert self.llm_chain
metadata_handler = MetadataCallbackHandler()
async for chunk in self.llm_chain.astream(
inputs,
config={"configurable": {"last_human_msg": message}},
config={
"configurable": {"last_human_msg": message},
"callbacks": [metadata_handler],
},
):
if not received_first_chunk:
# when receiving the first chunk, close the pending message and
Expand All @@ -142,7 +153,9 @@ async def process_message(self, message: HumanChatMessage):
break

# complete stream after all chunks have been streamed
self._send_stream_chunk(stream_id, "", complete=True)
self._send_stream_chunk(
stream_id, "", complete=True, metadata=metadata_handler.jai_metadata
)

async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
return "\n\n".join(
Expand Down
18 changes: 17 additions & 1 deletion packages/jupyter-ai/jupyter_ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ class BaseAgentMessage(BaseModel):
this defaults to a description of `JupyternautPersona`.
"""

metadata: Dict[str, Any] = {}
"""
Message metadata set by a provider after fully processing an input. The
contents of this dictionary are provider-dependent, and can be any
dictionary with string keys. This field is not to be displayed directly to
the user, and is intended solely for developer purposes.
"""


class AgentChatMessage(BaseAgentMessage):
type: Literal["agent"] = "agent"
Expand All @@ -101,9 +109,17 @@ class AgentStreamMessage(BaseAgentMessage):
class AgentStreamChunkMessage(BaseModel):
type: Literal["agent-stream-chunk"] = "agent-stream-chunk"
id: str
"""ID of the parent `AgentStreamMessage`."""
content: str
"""The string to append to the `AgentStreamMessage` referenced by `id`."""
stream_complete: bool
"""Indicates whether this chunk message completes the referenced stream."""
"""Indicates whether this chunk completes the stream referenced by `id`."""
metadata: Dict[str, Any] = {}
"""
The metadata of the stream referenced by `id`. Metadata from the latest
chunk should override any metadata from previous chunks. See the docstring
on `BaseAgentMessage.metadata` for information.
"""


class HumanChatMessage(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/src/chat_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ export class ChatHandler implements IDisposable {
}

streamMessage.body += newMessage.content;
streamMessage.metadata = newMessage.metadata;
if (newMessage.stream_complete) {
streamMessage.complete = true;
}
Expand Down
4 changes: 4 additions & 0 deletions packages/jupyter-ai/src/components/chat-messages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ function sortMessages(
export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
const collaborators = useCollaboratorsContext();

if (props.message.type === 'agent-stream' && props.message.complete) {
console.log(props.message.metadata);
}

const sharedStyles: SxProps<Theme> = {
height: '24px',
width: '24px'
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/src/components/pending-messages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ export function PendingMessages(
time: lastMessage.time,
body: '',
reply_to: '',
persona: lastMessage.persona
persona: lastMessage.persona,
metadata: {}
});

// timestamp format copied from ChatMessage
Expand Down
2 changes: 2 additions & 0 deletions packages/jupyter-ai/src/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ export namespace AiService {
body: string;
reply_to: string;
persona: Persona;
metadata: Record<string, any>;
};

export type HumanChatMessage = {
Expand Down Expand Up @@ -172,6 +173,7 @@ export namespace AiService {
id: string;
content: string;
stream_complete: boolean;
metadata: Record<string, any>;
};

export type Request = ChatRequest | ClearRequest;
Expand Down

0 comments on commit 0884211

Please sign in to comment.