State persistence #955

Draft: wants to merge 4 commits into main
20 changes: 9 additions & 11 deletions docs/graph.md
@@ -156,12 +156,9 @@ class Increment(BaseNode): # (2)!


fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)!
result, history = fives_graph.run_sync(DivisibleBy5(4)) # (4)!
result = fives_graph.run_sync(DivisibleBy5(4)) # (4)!
print(result)
#> 5
# the full history is quite verbose (see below), so we'll just print the summary
print([item.data_snapshot() for item in history])
#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)]
```

1. The `DivisibleBy5` node is parameterized with `None` for the state param and `None` for the deps param as this graph doesn't use state or deps, and `int` as it can end the run.
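
For readers seeing only this hunk: note 1 refers to the node definitions just above the changed lines, which the hunk does not show. A condensed sketch of that parameterization, reconstructed from the surrounding example (treat the exact generic parameters as an assumption, since they are not visible in this diff), looks roughly like this:

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class DivisibleBy5(BaseNode[None, None, int]):  # no state, no deps, may end the run with an int
    foo: int

    async def run(self, ctx: GraphRunContext) -> Increment | End[int]:
        if self.foo % 5 == 0:
            return End(self.foo)
        return Increment(self.foo)


@dataclass
class Increment(BaseNode):  # never ends the run, so the default type parameters suffice
    foo: int

    async def run(self, ctx: GraphRunContext) -> DivisibleBy5:
        return DivisibleBy5(self.foo + 1)


fives_graph = Graph(nodes=[DivisibleBy5, Increment])
```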
@@ -464,7 +461,7 @@ async def main():
)
state = State(user)
feedback_graph = Graph(nodes=(WriteEmail, Feedback))
email, _ = await feedback_graph.run(WriteEmail(), state=state)
email = await feedback_graph.run(WriteEmail(), state=state)
print(email)
"""
Email(
@@ -576,27 +573,28 @@ In this example, an AI asks the user a question, the user provides an answer, th

_(This example is complete, it can be run "as is" with Python 3.10+)_


```python {title="ai_q_and_a_run.py" noqa="I001" py="3.10"}
from rich.prompt import Prompt

from pydantic_graph import End, HistoryStep
from pydantic_graph import End, FullStatePersistence

from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer


async def main():
state = QuestionState() # (1)!
node = Ask() # (2)!
history: list[HistoryStep[QuestionState]] = [] # (3)!
persistence = FullStatePersistence() # (3)!
while True:
node = await question_graph.next(node, history, state=state) # (4)!
node = await question_graph.next( # (4)!
node, persistence=persistence, state=state
)
if isinstance(node, Answer):
node.answer = Prompt.ask(node.question) # (5)!
elif isinstance(node, End): # (6)!
print(f'Correct answer! {node.data}')
#> Correct answer! Well done, 1 + 1 = 2
print([e.data_snapshot() for e in history])
print([e.node for e in persistence.history])
"""
[
Ask(),
@@ -614,7 +612,7 @@

1. Create the state object which will be mutated by [`next`][pydantic_graph.graph.Graph.next].
2. The start node is `Ask` but will be updated by [`next`][pydantic_graph.graph.Graph.next] as the graph runs.
3. The history of the graph run is stored in a list of [`HistoryStep`][pydantic_graph.state.HistoryStep] objects. Again [`next`][pydantic_graph.graph.Graph.next] will update this list in place.
3. The history of the graph run is stored using [`FullStatePersistence`][pydantic_graph.state.memory.FullStatePersistence]. Again [`next`][pydantic_graph.graph.Graph.next] will append a snapshot to `persistence.history` on each step as the graph runs (a save/resume sketch follows this list).
4. [Run][pydantic_graph.graph.Graph.next] the graph one node at a time, updating the state, current node and history as the graph runs.
5. If the current node is an `Answer` node, prompt the user for an answer.
6. Since we're using [`next`][pydantic_graph.graph.Graph.next] we have to manually check for an [`End`][pydantic_graph.nodes.End] and exit the loop if we get one.
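
Since the whole point of this PR is persisting graph state between runs, here is a minimal save/resume sketch assembled only from calls visible in this draft diff (`FullStatePersistence`, `set_persistence_types`, `load_json`/`dump_json`, `persistence.history`, and `next(..., persistence=...)`). It is an illustration rather than part of the documentation change, and the resume logic is deliberately simplified: it always re-enters at `Ask`, whereas `question_graph.py` below continues from the node implied by the stored state.

```python
from pathlib import Path

from pydantic_graph import End, FullStatePersistence

from ai_q_and_a_graph import Answer, Ask, question_graph, QuestionState


async def ask_and_snapshot(path: Path) -> None:
    persistence = FullStatePersistence()
    question_graph.set_persistence_types(persistence)  # register snapshot types before (de)serializing

    if path.exists():
        persistence.load_json(path.read_bytes())  # restore snapshots written by a previous process
        last = persistence.history[-1]
        assert last.kind == 'node', 'expected last step to be a node'
        state = last.state  # pick the state back up where that run stopped
    else:
        state = QuestionState()

    node = Ask()  # simplified: always re-enter at Ask
    while True:
        node = await question_graph.next(node, persistence=persistence, state=state)
        path.write_bytes(persistence.dump_json(indent=2))  # snapshot after every step
        if isinstance(node, (Answer, End)):  # stop when user input is needed, or when the run ends
            break
```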
37 changes: 22 additions & 15 deletions examples/pydantic_ai_examples/question_graph.py
@@ -13,7 +13,14 @@

import logfire
from devtools import debug
from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, HistoryStep
from pydantic_graph import (
BaseNode,
Edge,
End,
FullStatePersistence,
Graph,
GraphRunContext,
)

from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
@@ -116,12 +123,12 @@ async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
async def run_as_continuous():
state = QuestionState()
node = Ask()
history: list[HistoryStep[QuestionState, None]] = []
persistence = FullStatePersistence()
with logfire.span('run questions graph'):
while True:
node = await question_graph.next(node, history, state=state)
node = await question_graph.next(node, persistence=persistence, state=state)
if isinstance(node, End):
debug([e.data_snapshot() for e in history])
debug([e.node for e in persistence.history])
break
elif isinstance(node, Answer):
assert state.question
@@ -131,14 +138,14 @@ async def run_as_continuous():

async def run_as_cli(answer: str | None):
history_file = Path('question_graph_history.json')
history = (
question_graph.load_history(history_file.read_bytes())
if history_file.exists()
else []
)

if history:
last = history[-1]
persistence = FullStatePersistence()
question_graph.set_persistence_types(persistence)

if history_file.exists():
persistence.load_json(history_file.read_bytes())

if persistence.history:
last = persistence.history[-1]
assert last.kind == 'node', 'expected last step to be a node'
state = last.state
assert answer is not None, 'answer is required to continue from history'
@@ -150,17 +157,17 @@ async def run_as_cli(answer: str | None):

with logfire.span('run questions graph'):
while True:
node = await question_graph.next(node, history, state=state)
node = await question_graph.next(node, persistence=persistence, state=state)
if isinstance(node, End):
debug([e.data_snapshot() for e in history])
debug([e.node for e in persistence.history])
print('Finished!')
break
elif isinstance(node, Answer):
print(state.question)
break
# otherwise just continue

history_file.write_bytes(question_graph.dump_history(history, indent=2))
history_file.write_bytes(persistence.dump_json(indent=2))


if __name__ == '__main__':
6 changes: 2 additions & 4 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -11,7 +11,7 @@
import logfire_api
from typing_extensions import TypeVar, deprecated

from pydantic_graph import Graph, GraphRunContext, HistoryStep
from pydantic_graph import Graph, GraphRunContext
from pydantic_graph.nodes import End

from . import (
@@ -337,7 +337,7 @@ async def main():
)

# Actually run
end_result, _ = await graph.run(
end_result = await graph.run(
start_node,
state=state,
deps=graph_deps,
@@ -583,7 +583,6 @@ async def main():

# Actually run
node = start_node
history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = []
while True:
if isinstance(node, _agent_graph.StreamModelRequestNode):
node = cast(
@@ -599,7 +598,6 @@ async def main():
assert not isinstance(node, End) # the previous line should be hit first
node = await graph.next(
node,
history,
state=graph_state,
deps=graph_deps,
infer_name=False,
5 changes: 1 addition & 4 deletions pydantic_graph/README.md
@@ -50,10 +50,7 @@ class Increment(BaseNode):


fives_graph = Graph(nodes=[DivisibleBy5, Increment])
result, history = fives_graph.run_sync(DivisibleBy5(4))
result = fives_graph.run_sync(DivisibleBy5(4))
print(result)
#> 5
# the full history is quite verbose (see below), so we'll just print the summary
print([item.data_snapshot() for item in history])
#> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)]
```
11 changes: 7 additions & 4 deletions pydantic_graph/pydantic_graph/__init__.py
@@ -1,17 +1,20 @@
from .exceptions import GraphRuntimeError, GraphSetupError
from .graph import Graph
from .nodes import BaseNode, Edge, End, GraphRunContext
from .state import EndStep, HistoryStep, NodeStep
from .state import EndSnapshot, NodeSnapshot, Snapshot
from .state.memory import FullStatePersistence, SimpleStatePersistence

__all__ = (
'Graph',
'BaseNode',
'End',
'GraphRunContext',
'Edge',
'EndStep',
'HistoryStep',
'NodeStep',
'EndSnapshot',
'Snapshot',
'NodeSnapshot',
'GraphSetupError',
'GraphRuntimeError',
'SimpleStatePersistence',
'FullStatePersistence',
)
5 changes: 0 additions & 5 deletions pydantic_graph/pydantic_graph/_utils.py
@@ -2,7 +2,6 @@

import sys
import types
from datetime import datetime, timezone
from typing import Annotated, Any, TypeVar, Union, get_args, get_origin

import typing_extensions
@@ -80,10 +79,6 @@ def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None
return back.f_locals


def now_utc() -> datetime:
return datetime.now(tz=timezone.utc)


class Unset:
"""A singleton to represent an unset value.
