Skip to content

Commit

Permalink
episodeId and other changes + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkpiano committed Nov 29, 2024
1 parent 4561453 commit 4a79bac
Show file tree
Hide file tree
Showing 11 changed files with 374 additions and 81 deletions.
10 changes: 10 additions & 0 deletions .changeset/nice-pants-rule.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
'@statelyai/agent': major
---

- The `machine` and `machineHash` properties were removed from `AgentObservation` and `AgentObservationInput`
- The `defaultOptions` property was removed from `Agent`
- `AgentDecideOptions` was renamed to `AgentDecideInput`
- The `execute` property was removed from `AgentDecideInput`
- The `episodeId` optional property was added to `AgentDecideInput`, `AgentObservationInput`, and `AgentFeedbackInput`
- `decisionId` was added to `AgentObservationInput` and `AgentFeedbackInput`
175 changes: 175 additions & 0 deletions architecture.tldr
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
{
"tldrawFileFormatVersion": 1,
"schema": {
"schemaVersion": 2,
"sequences": {
"com.tldraw.store": 4,
"com.tldraw.asset": 1,
"com.tldraw.camera": 1,
"com.tldraw.document": 2,
"com.tldraw.instance": 25,
"com.tldraw.instance_page_state": 5,
"com.tldraw.page": 1,
"com.tldraw.instance_presence": 5,
"com.tldraw.pointer": 1,
"com.tldraw.shape": 4,
"com.tldraw.asset.bookmark": 2,
"com.tldraw.asset.image": 5,
"com.tldraw.asset.video": 5,
"com.tldraw.shape.group": 0,
"com.tldraw.shape.text": 2,
"com.tldraw.shape.bookmark": 2,
"com.tldraw.shape.draw": 2,
"com.tldraw.shape.geo": 9,
"com.tldraw.shape.note": 8,
"com.tldraw.shape.line": 5,
"com.tldraw.shape.frame": 0,
"com.tldraw.shape.arrow": 5,
"com.tldraw.shape.highlight": 1,
"com.tldraw.shape.embed": 4,
"com.tldraw.shape.image": 4,
"com.tldraw.shape.video": 2,
"com.tldraw.binding.arrow": 0
}
},
"records": [
{
"gridSize": 10,
"name": "",
"meta": {},
"id": "document:document",
"typeName": "document"
},
{
"meta": {},
"id": "page:page",
"name": "Page 1",
"index": "a1",
"typeName": "page"
},
{
"id": "pointer:pointer",
"typeName": "pointer",
"x": 369.44140625,
"y": 256.55078125,
"lastActivityTimestamp": 1732900622027,
"meta": {}
},
{
"followingUserId": null,
"opacityForNextShape": 1,
"stylesForNextShape": {
"tldraw:size": "s"
},
"brush": null,
"scribbles": [],
"cursor": {
"type": "default",
"rotation": 0
},
"isFocusMode": false,
"exportBackground": true,
"isDebugMode": false,
"isToolLocked": false,
"screenBounds": {
"x": 0,
"y": 0,
"w": 1361,
"h": 684
},
"insets": [
false,
false,
true,
false
],
"zoomBrush": null,
"isGridMode": false,
"isPenMode": false,
"chatMessage": "",
"isChatting": false,
"highlightedUserIds": [],
"isFocused": true,
"devicePixelRatio": 2,
"isCoarsePointer": false,
"isHoveringCanvas": true,
"openMenus": [],
"isChangingStyle": false,
"isReadonly": false,
"meta": {},
"duplicateProps": null,
"id": "instance:instance",
"currentPageId": "page:page",
"typeName": "instance"
},
{
"editingShapeId": null,
"croppingShapeId": null,
"selectedShapeIds": [
"shape:iFECnSIn2r2wJYVAgyVIm"
],
"hoveredShapeId": "shape:iFECnSIn2r2wJYVAgyVIm",
"erasingShapeIds": [],
"hintingShapeIds": [],
"focusedGroupId": null,
"meta": {},
"id": "instance_page_state:page:page",
"pageId": "page:page",
"typeName": "instance_page_state"
},
{
"x": 0,
"y": 0,
"z": 1,
"meta": {},
"id": "camera:page:page",
"typeName": "camera"
},
{
"x": 326.71875,
"y": 152.83984375,
"rotation": 0,
"isLocked": false,
"opacity": 1,
"meta": {},
"id": "shape:Az6bLzi8VaAX36Icv2AmG",
"type": "text",
"props": {
"color": "black",
"size": "m",
"w": 81.3828125,
"text": "decide",
"font": "draw",
"textAlign": "start",
"autoSize": true,
"scale": 1
},
"parentId": "page:page",
"index": "a1",
"typeName": "shape"
},
{
"x": 338.31640625,
"y": 193.2578125,
"rotation": 0,
"isLocked": false,
"opacity": 1,
"meta": {},
"id": "shape:iFECnSIn2r2wJYVAgyVIm",
"type": "text",
"props": {
"color": "black",
"size": "s",
"w": 93.2890625,
"text": "episodeId\nstate\ngoal\ncontext",
"font": "draw",
"textAlign": "start",
"autoSize": true,
"scale": 1
},
"parentId": "page:page",
"index": "a290f",
"typeName": "shape"
}
]
}
4 changes: 2 additions & 2 deletions examples/learn-from-feedback.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Achieve the goal. Consider both exploring unknown actions (high exploration_valu

if (decision?.nextEvent?.type === 'submit') {
const observation = await agent.addObservation({
goal: decision.goal,
decisionId: decision.id,
prevState: { value: 'editing' },
event: { type: 'submit' },
state: { value: 'editing' },
Expand All @@ -96,7 +96,7 @@ Achieve the goal. Consider both exploring unknown actions (high exploration_valu
status = 'submitted';

await agent.addObservation({
goal: decision.goal,
decisionId: decision.id,
prevState: { value: 'editing' },
event: { type: 'pressEnter' },
state: { value: 'submitted' },
Expand Down
114 changes: 106 additions & 8 deletions src/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { test, expect, vi } from 'vitest';
import { createAgent, TypesFromAgent } from './';
import { AgentDecision, createAgent, TypesFromAgent } from './';
import { createActor, createMachine } from 'xstate';
import { LanguageModelV1CallOptions } from 'ai';
import { z } from 'zod';
Expand Down Expand Up @@ -66,15 +66,30 @@ test('agent.addMessage() adds to message history', () => {
test('agent.addFeedback() adds to feedback', () => {
const agent = createAgent({
id: 'test',
events: {},
events: {
play: z.object({
position: z.number(),
}),
},
model: {} as any,
});

const decision: AgentDecision<typeof agent> = {
goal: 'Win the game',
episodeId: agent.episodeId,
goalState: { value: 'won' },
id: 'decision-1',
nextEvent: { type: 'play', position: 3 },
paths: [],
strategy: 'simple',
timestamp: Date.now(),
};

const obs = agent.addObservation({
decisionId: decision.id,
prevState: { value: 'playing' },
state: { value: 'lost' },
event: { type: 'play', position: 3 },
goal: 'Win the game',
state: { value: 'lost' },
});

const feedback = agent.addFeedback({
Expand Down Expand Up @@ -167,7 +182,6 @@ test('agent.addObservation() adds to observations with machine hash', () => {
prevState: { value: 'playing', context: {} },
event: { type: 'play', position: 3 },
state: { value: 'lost', context: {} },
machine,
goal: 'Win the game',
});

Expand All @@ -178,7 +192,6 @@ test('agent.addObservation() adds to observations with machine hash', () => {
prevState: { value: 'playing', context: {} },
event: { type: 'play', position: 3 },
state: { value: 'lost', context: {} },
machineHash: expect.any(String),
episodeId: expect.any(String),
timestamp: expect.any(Number),
})
Expand Down Expand Up @@ -503,7 +516,6 @@ test('agent.observe() adds observations from actor snapshots', () => {
expect(agent.getObservations()).toContainEqual(
expect.objectContaining({
state: expect.objectContaining({ value: 'idle' }),
machineHash: expect.any(String),
})
);

Expand All @@ -512,9 +524,95 @@ test('agent.observe() adds observations from actor snapshots', () => {
prevState: expect.objectContaining({ value: 'idle' }),
event: { type: 'START' },
state: expect.objectContaining({ value: 'running' }),
machineHash: expect.any(String),
})
);

subscription.unsubscribe();
});

test('agent.addObservation() accepts custom episodeId', () => {
const agent = createAgent({
id: 'test',
events: {},
model: {} as any,
});

const customEpisodeId = 'custom-episode-123';
const observation = agent.addObservation({
state: { value: 'playing' },
goal: 'Win the game',
episodeId: customEpisodeId,
});

expect(observation.episodeId).toEqual(customEpisodeId);
expect(agent.getObservations()).toContainEqual(
expect.objectContaining({
episodeId: customEpisodeId,
})
);
});

test('agent.addFeedback() accepts custom episodeId', () => {
const agent = createAgent({
id: 'test',
events: {},
model: {} as any,
});

const customEpisodeId = 'custom-episode-123';
const feedback = agent.addFeedback({
score: 1,
observationId: 'obs-1',
episodeId: customEpisodeId,
});

expect(feedback.episodeId).toEqual(customEpisodeId);
expect(agent.getFeedback()).toContainEqual(
expect.objectContaining({
episodeId: customEpisodeId,
})
);
});

test('agent.addObservation() accepts decisionId', () => {
const agent = createAgent({
id: 'test',
events: {},
model: {} as any,
});

const decisionId = 'decision-123';
const observation = agent.addObservation({
state: { value: 'playing' },
goal: 'Win the game',
decisionId,
});

expect(observation.decisionId).toEqual(decisionId);
expect(agent.getObservations()).toContainEqual(
expect.objectContaining({
decisionId,
})
);
});

test('agent.addFeedback() accepts decisionId', () => {
const agent = createAgent({
id: 'test',
events: {},
model: {} as any,
});

const decisionId = 'decision-123';
const feedback = agent.addFeedback({
score: 1,
decisionId,
});

expect(feedback.decisionId).toEqual(decisionId);
expect(agent.getFeedback()).toContainEqual(
expect.objectContaining({
decisionId,
})
);
});
Loading

0 comments on commit 4a79bac

Please sign in to comment.