Skip to content

Commit

Permalink
add model and chunks to llama index callback handler (#876)
Browse files Browse the repository at this point in the history
  • Loading branch information
willydouhard authored Apr 4, 2024
1 parent 3f858cd commit c18a578
Showing 1 changed file with 38 additions and 6 deletions.
44 changes: 38 additions & 6 deletions backend/chainlit/llama_index/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
CBEventType.SYNTHESIZE,
CBEventType.EMBEDDING,
CBEventType.NODE_PARSING,
CBEventType.QUERY,
CBEventType.TREE,
]

Expand Down Expand Up @@ -71,9 +70,12 @@ def on_event_start(
) -> str:
"""Run when an event starts and return id of event."""
self._restore_context()

step_type: StepType = "undefined"
if event_type == CBEventType.RETRIEVE:
step_type = "retrieval"
elif event_type == CBEventType.QUERY:
step_type = "retrieval"
elif event_type == CBEventType.LLM:
step_type = "llm"
else:
Expand All @@ -84,7 +86,7 @@ def on_event_start(
type=step_type,
parent_id=self._get_parent_id(parent_id),
id=event_id,
disable_feedback=False,
disable_feedback=True,
)
self.steps[event_id] = step
step.start = utc_now()
Expand All @@ -102,17 +104,34 @@ def on_event_end(
"""Run when an event ends."""
step = self.steps.get(event_id, None)


if payload is None or step is None:
return

self._restore_context()

step.end = utc_now()

if event_type == CBEventType.RETRIEVE:
if event_type == CBEventType.QUERY:
response = payload.get(EventPayload.RESPONSE)
source_nodes = getattr(response, "source_nodes", None)
if source_nodes:
source_refs = ", ".join(
[f"Source {idx}" for idx, _ in enumerate(source_nodes)])
step.elements = [
Text(
name=f"Source {idx}",
content=source.text or "Empty node",
)
for idx, source in enumerate(source_nodes)
]
step.output = f"Retrieved the following sources: {source_refs}"
self.context.loop.create_task(step.update())

elif event_type == CBEventType.RETRIEVE:
sources = payload.get(EventPayload.NODES)
if sources:
source_refs = "\, ".join(
source_refs = ", ".join(
[f"Source {idx}" for idx, _ in enumerate(sources)]
)
step.elements = [
Expand All @@ -125,7 +144,7 @@ def on_event_end(
step.output = f"Retrieved the following sources: {source_refs}"
self.context.loop.create_task(step.update())

if event_type == CBEventType.LLM:
elif event_type == CBEventType.LLM:
formatted_messages = payload.get(
EventPayload.MESSAGES
) # type: Optional[List[ChatMessage]]
Expand All @@ -152,10 +171,15 @@ def on_event_end(
step.output = content

token_count = self.total_llm_token_count or None

raw_response = response.raw if response else None
model = raw_response.get("model", None) if raw_response else None
provider = "openai"

if messages and isinstance(response, ChatResponse):
msg: ChatMessage = response.message
step.generation = ChatGeneration(
provider=provider,
model=model,
messages=messages,
message_completion=GenerationMessage(
role=msg.role.value, # type: ignore
Expand All @@ -165,17 +189,25 @@ def on_event_end(
)
elif formatted_prompt:
step.generation = CompletionGeneration(
provider=provider,
model=model,
prompt=formatted_prompt,
completion=content,
token_count=token_count,
)

self.context.loop.create_task(step.update())

else:
step.output = payload.get
self.context.loop.create_task(step.update())
return

self.steps.pop(event_id, None)

# NOTE(review): presumably these exist to satisfy the LlamaIndex callback
# handler interface, which expects start_trace/end_trace methods — this
# handler tracks work per event (on_event_start/on_event_end) instead, so
# tracing callbacks are deliberately ignored. TODO confirm against the
# LlamaIndex BaseCallbackHandler contract.
def _noop(self, *args, **kwargs):
    """Accept any arguments and do nothing."""
    pass

# Alias both tracing hooks to the shared no-op.
start_trace = _noop
end_trace = _noop

0 comments on commit c18a578

Please sign in to comment.