Skip to content

Commit

Permalink
Remove goal from feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkpiano committed Nov 24, 2024
1 parent 57070f8 commit 9d65d71
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 32 deletions.
5 changes: 5 additions & 0 deletions .changeset/swift-mangos-rush.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@statelyai/agent": patch
---

Remove `goal` from feedback input
3 changes: 2 additions & 1 deletion examples/learn-from-feedback.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ Achieve the goal. Consider both exploring unknown actions (high exploration_valu

if (decision?.nextEvent?.type === 'submit') {
const observation = await agent.addObservation({
goal: decision.goal,
prevState: { value: 'editing' },
event: { type: 'submit' },
state: { value: 'editing' },
Expand All @@ -88,14 +89,14 @@ Achieve the goal. Consider both exploring unknown actions (high exploration_valu
// don't change the status; pretend submit button is broken
await agent.addFeedback({
observationId: observation.id,
goal: 'Submit the form',
score: 0,
comment: 'Form not submitted',
});
} else if (decision?.nextEvent?.type === 'pressEnter') {
status = 'submitted';

await agent.addObservation({
goal: decision.goal,
prevState: { value: 'editing' },
event: { type: 'pressEnter' },
state: { value: 'submitted' },
Expand Down
41 changes: 19 additions & 22 deletions src/agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,28 +70,24 @@ test('agent.addFeedback() adds to feedback', () => {
model: {} as any,
});

const feedback = agent.addFeedback({
score: -1,
const obs = agent.addObservation({
prevState: { value: 'playing' },
state: { value: 'lost' },
event: { type: 'play', position: 3 },
goal: 'Win the game',
observationId: 'obs-1',
});

const feedback = agent.addFeedback({
score: 0,
observationId: obs.id,
});

expect(feedback.episodeId).toEqual(agent.episodeId);

expect(agent.getFeedback()).toContainEqual(
expect.objectContaining({
score: -1,
goal: 'Win the game',
observationId: 'obs-1',
episodeId: expect.any(String),
timestamp: expect.any(Number),
})
);
expect(agent.getFeedback()).toContainEqual(
expect.objectContaining({
score: -1,
goal: 'Win the game',
observationId: 'obs-1',
score: 0,
observationId: obs.id,
episodeId: expect.any(String),
timestamp: expect.any(Number),
})
Expand All @@ -109,6 +105,7 @@ test('agent.addObservation() adds to observations', () => {
prevState: { value: 'playing', context: {} },
event: { type: 'play', position: 3 },
state: { value: 'lost', context: {} },
goal: 'Win the game',
});

expect(observation.episodeId).toEqual(agent.episodeId);
Expand All @@ -133,6 +130,7 @@ test('agent.addObservation() adds to observations (initial state)', () => {

const observation = agent.addObservation({
state: { value: 'lost' },
goal: 'Win the game',
});

expect(observation.episodeId).toEqual(agent.episodeId);
Expand Down Expand Up @@ -170,6 +168,7 @@ test('agent.addObservation() adds to observations with machine hash', () => {
event: { type: 'play', position: 3 },
state: { value: 'lost', context: {} },
machine,
goal: 'Win the game',
});

expect(observation.episodeId).toEqual(agent.episodeId);
Expand Down Expand Up @@ -197,29 +196,27 @@ test('agent.addFeedback() adds to feedback (with observation)', () => {
state: {
value: 'playing',
},
goal: 'Win the game',
});

const feedback = agent.addFeedback({
score: -1,
goal: 'Win the game',
score: 0,
observationId: observation.id,
});

expect(feedback.episodeId).toEqual(agent.episodeId);

expect(agent.getFeedback()).toContainEqual(
expect.objectContaining({
score: -1,
goal: 'Win the game',
score: 0,
observationId: observation.id,
episodeId: expect.any(String),
timestamp: expect.any(Number),
})
);
expect(agent.getFeedback()).toContainEqual(
expect.objectContaining({
score: -1,
goal: 'Win the game',
score: 0,
observationId: observation.id,
episodeId: expect.any(String),
timestamp: expect.any(Number),
Expand Down Expand Up @@ -286,7 +283,6 @@ test('You can listen for feedback events', () => {

agent.addFeedback({
score: -1,
goal: 'Win the game',
observationId: 'obs-1',
});

Expand Down Expand Up @@ -431,6 +427,7 @@ test('agent.getDecisions() returns decisions from context', () => {
model: {} as any,
strategy: async (agent) => {
return {
id: Date.now().toString(),
episodeId: agent.episodeId,
strategy: 'test-strategy',
goal: '',
Expand Down
32 changes: 26 additions & 6 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -391,17 +391,19 @@ export class Agent<
) {
const observation = agent.addObservation(observationInput);

const input = getInput?.(observation);
const interactInput = getInput?.(observation);

if (input) {
const res = await agentDecide(agent, {
if (interactInput) {
const decision = await agentDecide(agent, {
machine,
state: observation.state,
...input,
...interactInput,
});

if (res?.nextEvent) {
actorRef.send(res.nextEvent);
if (decision?.nextEvent) {
// @ts-ignore
decision.nextEvent['_decision'] = decision.id;
actorRef.send(decision.nextEvent);
}
}

Expand All @@ -420,11 +422,20 @@ export class Agent<
return;
}

const decisionId = inspEvent.event['_decision'] as
| string
| undefined;

const decision = decisionId
? agent.getDecisions().find((d) => d.id === decisionId)
: undefined;

const observationInput = {
event: inspEvent.event,
prevState,
state: inspEvent.snapshot as any,
machine: (actorRef as any).src,
goal: decision?.goal,
} satisfies AgentObservationInput<any>;

await handleObservation(observationInput);
Expand All @@ -439,6 +450,7 @@ export class Agent<
event: undefined,
state: actorRef.getSnapshot(),
machine: (actorRef as any).src,
goal: undefined,
});
}

Expand All @@ -464,11 +476,19 @@ export class Agent<
return;
}

const decisionId = inspEvent.event['_decision'] as
| string
| undefined;
const decision = decisionId
? this.getDecisions().find((d) => d.id === decisionId)
: undefined;

const observationInput = {
event: inspEvent.event,
prevState,
state: inspEvent.snapshot as any,
machine: (actorRef as any).src,
goal: decision?.goal,
} satisfies AgentObservationInput<this>;

prevState = observationInput.state;
Expand Down
2 changes: 2 additions & 0 deletions src/strategies/shortestPath.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { z } from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import Ajv from 'ajv';
import { AnyMachineSnapshot } from 'xstate';
import { randomId } from '../utils';

const ajv = new Ajv();

Expand Down Expand Up @@ -165,6 +166,7 @@ Examples:
const nextStep = leastWeightPath?.steps[0];

return {
id: randomId(),
strategy: 'shortestPath',
episodeId: agent.episodeId,
goal: input.goal,
Expand Down
1 change: 1 addition & 0 deletions src/strategies/simple.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ export async function simpleStrategy<T extends AnyAgent>(
}

return {
id: randomId(),
strategy: 'simple',
goal: input.goal,
goalState: input.state,
Expand Down
6 changes: 3 additions & 3 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ export type AgentPath<TAgent extends AnyAgent> = {
};

export type AgentDecision<TAgent extends AnyAgent> = {
id: string;
/**
* The strategy used to generate the decision
*/
Expand Down Expand Up @@ -170,7 +171,6 @@ export type AgentDecideOptions<TAgent extends AnyAgent> = {
} & Omit<Parameters<typeof generateText>[0], 'model' | 'tools' | 'prompt'>;

export interface AgentFeedback {
goal: string;
observationId: string;
score: number;
comment: string | undefined;
Expand All @@ -183,7 +183,6 @@ export interface AgentFeedback {
}

export interface AgentFeedbackInput {
goal: string;
observationId: string;
score: number;
comment?: string;
Expand Down Expand Up @@ -331,7 +330,7 @@ export type AgentMessageInput = CoreMessage & {

export interface AgentObservation<TActor extends ActorRefLike> {
id: string;
// TODO: goal
goal?: string;
prevState: SnapshotFrom<TActor> | undefined;
event: EventFrom<TActor> | undefined;
state: SnapshotFrom<TActor>;
Expand All @@ -347,6 +346,7 @@ export interface AgentObservationInput<TAgent extends AnyAgent> {
state: ObservedState<TAgent>;
machine?: AnyStateMachine;
timestamp?: number;
goal: string | undefined;
}

export type AgentDecisionInput = {
Expand Down

0 comments on commit 9d65d71

Please sign in to comment.