feat: integrated agent, live evaluations
+ feat: evaluations run against live app
+ feat: deployed LLM agent to Functions
+ feat: integrated agent into app
telpirion authored Nov 27, 2024
2 parents e6c1555 + bfd477b commit f6e8015
Showing 18 changed files with 451 additions and 117 deletions.
31 changes: 26 additions & 5 deletions README.md
@@ -40,8 +40,7 @@ This system allows the usage of three related LLM models:

 + The out-of-the-box [Gemini 1.5 Flash model][gemini]
 + A tuned version of the Gemini 1.5 Flash model, trained on the [Guanaco dataset][guanaco].
-+ A [Gemma 2][gemma2] open source model. This model currently cannot be
-  evaluated with the Evaluations API.
++ A [Gemma 2][gemma2] open source model.

These models have been evaluated against the following set of metrics.

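One of the metrics, [ROUGE][rouge], scores a response by its n-gram overlap with a reference answer. As an illustration only (a simplified ROUGE-1 recall sketch, not the Vertex AI Evaluations implementation):

```go
package main

import (
	"fmt"
	"strings"
)

// rouge1Recall computes a simplified ROUGE-1 recall: the fraction of
// reference unigrams that also appear in the candidate response.
func rouge1Recall(candidate, reference string) float64 {
	refTokens := strings.Fields(strings.ToLower(reference))
	if len(refTokens) == 0 {
		return 0
	}
	candSet := map[string]bool{}
	for _, t := range strings.Fields(strings.ToLower(candidate)) {
		candSet[t] = true
	}
	matched := 0
	for _, t := range refTokens {
		if candSet[t] {
			matched++
		}
	}
	return float64(matched) / float64(len(refTokens))
}

func main() {
	fmt.Println(rouge1Recall("stay in chania old town", "stay in chania"))
}
```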
@@ -55,10 +54,29 @@ The following table shows the evaluation scores for each of these models.

 | Model            | ROUGE  | Closed domain | Open domain | Groundedness | Coherence | Date of eval |
 | ---------------- | ------ | ------------- | ----------- | ------------ | --------- | ------------ |
-| Gemini 1.5 Flash | 1.0[1] | 0.52          | 1.0         | 1.0[1]       | 3.8       | 2024-11-07   |
-| Tuned Gemini     | 0.41   | 0.8           | 1.0         | 0.6          | 3.8       | 2024-11-07   |
+| Gemini 1.5 Flash | 0.20[1]| 0.0           | 1.0         | 1.0[1]       | 3.3       | 2024-11-25   |
+| Tuned Gemini     | 0.21   | 0.4           | 1.0         | 1.0          | 2.4       | 2024-11-25   |
+| Gemma            | 0.05   | 0.6           | 0.4         | 0.8          | 1.4       | 2024-11-25   |

-[1]: Gemini 1.5 Flash responses were used as the ground truth for all other models.
+[1]: Gemini 1.5 Flash responses from 2024-11-05 are used as the ground truth
+for all other models.
+
+## Adversarial evaluations
+
+These models have been evaluated against the following set of adversarial
+techniques.
+
++ [Prompt injection][injection]
++ [Prompt leaking][leaking]
++ [Jailbreaking][jailbreaking]
+
+The following table shows the evaluation scores for adversarial prompting.
+
+| Model            | Prompt injection | Prompt leaking | Jailbreaking | Date of eval |
+| ---------------- | ---------------- | -------------- | ------------ | ------------ |
+| Gemini 1.5 Flash | 0.66             | 0.66           | 1.0          | 2024-11-25   |
+| Tuned Gemini     | 0.33             | 1.0            | 1.0          | 2024-11-25   |
+| Gemma            | 1.0              | 0.66           | 0.66        | 2024-11-25   |
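Scores such as 0.66 are consistent with a pass fraction over a small prompt set (2 of 3 adversarial prompts resisted); assuming that rubric, which is an assumption and not stated in the commit, the scoring reduces to:

```go
package main

import "fmt"

// adversarialScore returns the fraction of adversarial prompts a model
// resisted. A value of 0.66 would correspond to 2 of 3 prompts resisted;
// the exact rubric used in the repo is an assumption here.
func adversarialScore(resisted []bool) float64 {
	if len(resisted) == 0 {
		return 0
	}
	passed := 0
	for _, ok := range resisted {
		if ok {
			passed++
		}
	}
	return float64(passed) / float64(len(resisted))
}

func main() {
	fmt.Printf("%.2f\n", adversarialScore([]bool{true, true, false}))
}
```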

[bigquery]: https://cloud.google.com/bigquery/docs
[bulma]: https://bulma.io/documentation/components/message/
@@ -75,6 +93,9 @@ The following table shows the evaluation scores for each of these models.
[groundedness]: https://cloud.google.com/vertex-ai/generative-ai/docs/models/metrics-templates#pointwise_groundedness
[guanaco]: https://huggingface.co/datasets/timdettmers/openassistant-guanaco
[herodotus]: https://en.wikipedia.org/wiki/Herodotus
+[injection]: https://www.promptingguide.ai/prompts/adversarial-prompting/prompt-injection
+[jailbreaking]: https://www.promptingguide.ai/prompts/adversarial-prompting/jailbreaking-llms
+[leaking]: https://www.promptingguide.ai/prompts/adversarial-prompting/prompt-leaking
[pytorch]: https://pytorch.org/
[rouge]: https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval#rouge
[run]: https://cloud.google.com/run/docs/overview/what-is-cloud-run
55 changes: 49 additions & 6 deletions docs/services.md
@@ -162,14 +162,57 @@ $ gcloud run jobs execute embeddings --region us-west1
 ## Reddit tool / agent

 The [Reddit tool](../services/reddit-tool/) allows the LLM to read [r/travel][subreddit] posts based
-upon a user query. The tool is packaged as a Vertex AI [Reasoning Engine agent][reasoning]. Internally,
-the tool uses [LangChain][langchain] along with the Vertex AI Python SDK to perform its
-magic.
+upon a user query. The tool is packaged as a Vertex AI [Reasoning Engine agent][reasoning].
+Internally, the tool uses [LangChain][langchain] along with the Vertex AI Python
+SDK to perform its magic.

-### Deploy the agent
-
-**NOTE**: You might need to install `pyenv` first before completing these instructions.
-See [Troubleshooting](./troubleshooting.md) for more details.
+**WARNING**: As of writing (2024-11-26), the Vertex AI Reasoning Engine agent
+doesn't work as intended. Instead, the agent is published to Cloud Functions.
+
+### Test the agent locally (Cloud Functions)
+
+1. Run the Cloud Function locally.
+
+   ```sh
+   functions-framework-python --target get_agent_request
+   ```
+
+1. Send a request to the app with `curl`.
+
+   ```sh
+   curl --header "Content-Type: application/json" \
+     --request POST \
+     --data '{"query":"I want to go to Crete. Where should I stay?"}' \
+     http://localhost:8080
+   ```
+
+Deployed location:
+<https://reddit-tool-1025771077852.us-west1.run.app>
+
+### Deploy the agent (Cloud Functions)
+
+Run the following from the root of the `reddit-tool` directory.
+
+```sh
+gcloud functions deploy reddit-tool \
+  --gen2 \
+  --memory=512MB \
+  --timeout=120s \
+  --runtime=python312 \
+  --region=us-west1 \
+  --set-env-vars PROJECT_ID=${PROJECT_ID},BUCKET=${BUCKET} \
+  --source=. \
+  --entry-point=get_agent_request \
+  --trigger-http \
+  --allow-unauthenticated
+```
+
+### Deploy the agent (Reasoning Engine)
+
+**NOTES**:
+
++ You might need to install `pyenv` first before completing these instructions.
+  See [Troubleshooting](./troubleshooting.md) for more details.
++

 1. Create a virtual environment. The virtual environment needs to have Python v3.6 <= x <= v3.11.
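The request the `curl` step sends can also be built programmatically. This Go sketch is a hypothetical helper (not part of the repo); the `query` field name comes from the example payload, and the endpoint is whatever address the function listens on:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// buildAgentRequest constructs the same POST the curl example sends:
// a JSON body {"query": ...} with a Content-Type header.
func buildAgentRequest(endpoint, query string) (*http.Request, error) {
	body, err := json.Marshal(map[string]string{"query": query})
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	return req, nil
}

func main() {
	req, err := buildAgentRequest("http://localhost:8080", "I want to go to Crete. Where should I stay?")
	if err != nil {
		panic(err)
	}
	fmt.Println(req.Method, req.URL, req.Header.Get("Content-Type"))
}
```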
49 changes: 49 additions & 0 deletions server/ai/reddit.go
@@ -0,0 +1,49 @@
package ai

import (
	"context"
	"fmt"

	"github.com/vartanbeno/go-reddit/v2/reddit"
)

const subredditName = "travel"

func getRedditPosts(location string) (string, error) {
	client, err := reddit.NewReadonlyClient()
	if err != nil {
		return "", err
	}

	ctx := context.Background()
	posts, _, err := client.Subreddit.SearchPosts(ctx, location, subredditName, &reddit.ListPostSearchOptions{
		ListPostOptions: reddit.ListPostOptions{
			ListOptions: reddit.ListOptions{
				Limit: 5,
			},
			Time: "all",
		},
	})
	if err != nil {
		return "", err
	}

	response := ""

	for _, post := range posts {
		if post.Body == "" {
			continue
		}

		// Fall back to the post alone when its comments can't be fetched
		// or the post has no comments.
		postAndComments, _, err := client.Post.Get(ctx, post.ID)
		if err != nil || len(postAndComments.Comments) == 0 {
			response += fmt.Sprintf("Title: %s, Post: %s\n", post.Title, post.Body)
			continue
		}

		response += fmt.Sprintf("Title: %s, Post: %s, Top Comment: %s\n",
			post.Title, post.Body, postAndComments.Comments[0].Body)
	}

	return response, nil
}
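The per-post string layout built inside `getRedditPosts` can be isolated as a pure helper for testing. This helper is illustrative only and not part of the commit:

```go
package main

import "fmt"

// formatPost mirrors the layout getRedditPosts builds for each post:
// with a top comment when one exists, and without one otherwise.
func formatPost(title, body, topComment string) string {
	if topComment == "" {
		return fmt.Sprintf("Title: %s, Post: %s\n", title, body)
	}
	return fmt.Sprintf("Title: %s, Post: %s, Top Comment: %s\n", title, body, topComment)
}

func main() {
	fmt.Print(formatPost("Crete tips", "Where should I stay?", "Try Chania."))
}
```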
148 changes: 110 additions & 38 deletions server/ai/vertex.go
@@ -79,6 +79,8 @@ func Predict(query, modality, projectID string) (response string, templateName s
 		response, err = textPredictGemma(query, projectID)
 	case GeminiTuned:
 		response, err = textPredictGemini(query, projectID, GeminiTuned)
+	case AgentAssisted:
+		response, err = textPredictWithReddit(query, projectID)
 	default:
 		response, err = textPredictGemini(query, projectID, Gemini)
 	}
@@ -125,6 +127,51 @@ func SetConversationContext(convoHistory []generated.ConversationBit) error {
 	return nil
 }

+// StoreConversationContext uploads past user conversations with the model into a Gen AI context.
+// This context is used when the model is answering questions from the user.
+func StoreConversationContext(conversationHistory []generated.ConversationBit, projectID string) (string, error) {
+	if len(conversationHistory) < MinimumConversationNum {
+		return "", &MinCacheNotReachedError{ConversationCount: len(conversationHistory)}
+	}
+
+	ctx := context.Background()
+	location := "us-west1"
+	client, err := genai.NewClient(ctx, projectID, location)
+	if err != nil {
+		return "", fmt.Errorf("unable to create client: %w", err)
+	}
+	defer client.Close()
+
+	var userParts []genai.Part
+	var modelParts []genai.Part
+	for _, p := range conversationHistory {
+		userParts = append(userParts, genai.Text(p.UserQuery))
+		modelParts = append(modelParts, genai.Text(p.BotResponse))
+	}
+
+	content := &genai.CachedContent{
+		Model:      GeminiModel,
+		Expiration: genai.ExpireTimeOrTTL{TTL: 60 * time.Minute},
+		Contents: []*genai.Content{
+			{
+				Role:  "user",
+				Parts: userParts,
+			},
+			{
+				Role:  "model",
+				Parts: modelParts,
+			},
+		},
+	}
+	result, err := client.CreateCachedContent(ctx, content)
+	if err != nil {
+		return "", fmt.Errorf("CreateCachedContent: %w", err)
+	}
+	resourceName := result.Name
+
+	return resourceName, nil
+}

// extractAnswer cleans up the response returned from the models
func extractAnswer(response string) string {
// I am not a regex expert :/
@@ -271,57 +318,82 @@ func getCandidate(resp *genai.GenerateContentResponse) (string, error) {
return string(candidate), nil
}

-// storeConversationContext uploads past user conversations with the model into a Gen AI context.
-// This context is used when the model is answering questions from the user.
-func StoreConversationContext(conversationHistory []generated.ConversationBit, projectID string) (string, error) {
-	if len(conversationHistory) < MinimumConversationNum {
-		return "", &MinCacheNotReachedError{ConversationCount: len(conversationHistory)}
-	}
-
-	ctx := context.Background()
-	location := "us-west1"
-	client, err := genai.NewClient(ctx, projectID, location)
-	if err != nil {
-		return "", fmt.Errorf("unable to create client: %w", err)
-	}
-	defer client.Close()
-
-	var userParts []genai.Part
-	var modelParts []genai.Part
-	for _, p := range conversationHistory {
-		userParts = append(userParts, genai.Text(p.UserQuery))
-		modelParts = append(modelParts, genai.Text(p.BotResponse))
-	}
-
-	content := &genai.CachedContent{
-		Model:      GeminiModel,
-		Expiration: genai.ExpireTimeOrTTL{TTL: 60 * time.Minute},
-		Contents: []*genai.Content{
-			{
-				Role:  "user",
-				Parts: userParts,
-			},
-			{
-				Role:  "model",
-				Parts: modelParts,
-			},
-		},
-	}
-	result, err := client.CreateCachedContent(ctx, content)
-	if err != nil {
-		return "", fmt.Errorf("CreateCachedContent: %w", err)
-	}
-	resourceName := result.Name
-
-	return resourceName, nil
-}
-
 func trimContext() (last string) {
 	sep := "###"
 	convos := strings.Split(cachedContext, sep)
 	length := len(convos)
 	if len(convos) > 3 {
 		last = strings.Join(convos[length-3:length-1], sep)
 	}
 	return last
 }
+
+func textPredictWithReddit(query, projectID string) (string, error) {
+	funcName := "GetRedditPosts"
+	ctx := context.Background()
+	client, err := genai.NewClient(ctx, projectID, "us-west1")
+	if err != nil {
+		return "", err
+	}
+	defer client.Close()
+
+	schema := &genai.Schema{
+		Type: genai.TypeObject,
+		Properties: map[string]*genai.Schema{
+			"location": {
+				Type:        genai.TypeString,
+				Description: "the place the user wants to go, e.g. Crete, Greece",
+			},
+		},
+		Required: []string{"location"},
+	}
+
+	redditTool := &genai.Tool{
+		FunctionDeclarations: []*genai.FunctionDeclaration{{
+			Name:        funcName,
+			Description: "Get Reddit posts about a location from the Travel subreddit",
+			Parameters:  schema,
+		}},
+	}
+
+	model := client.GenerativeModel(GeminiModel)
+	model.Tools = []*genai.Tool{redditTool}
+
+	session := model.StartChat()
+
+	res, err := session.SendMessage(ctx, genai.Text(query))
+	if err != nil {
+		return "", err
+	}
+	part := res.Candidates[0].Content.Parts[0]
+	funcCall, ok := part.(genai.FunctionCall)
+	if !ok {
+		return "", fmt.Errorf("expected function call: %v", part)
+	}
+	if funcCall.Name != funcName {
+		return "", fmt.Errorf("expected %s, got: %v", funcName, funcCall.Name)
+	}
+	locArg, ok := funcCall.Args["location"].(string)
+	if !ok {
+		return "", fmt.Errorf("expected string, got: %v", funcCall.Args["location"])
+	}
+
+	redditData, err := getRedditPosts(locArg)
+	if err != nil {
+		return "", err
+	}
+
+	res, err = session.SendMessage(ctx, genai.FunctionResponse{
+		Name: redditTool.FunctionDeclarations[0].Name,
+		Response: map[string]any{
+			"output": redditData,
+		},
+	})
+	if err != nil {
+		return "", err
+	}
+
+	output := string(res.Candidates[0].Content.Parts[0].(genai.Text))
+	return output, nil
+}
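The windowing idea behind `trimContext` can be shown as a standalone sketch. Names here are hypothetical; the real function reads the package-level `cachedContext` rather than taking an argument:

```go
package main

import (
	"fmt"
	"strings"
)

// trimConversation keeps only the two most recent completed exchanges from a
// "###"-separated conversation log, mirroring the convos[length-3:length-1]
// slice in trimContext. Short logs yield an empty string, as in the original.
func trimConversation(log string) string {
	sep := "###"
	convos := strings.Split(log, sep)
	n := len(convos)
	if n <= 3 {
		return ""
	}
	return strings.Join(convos[n-3:n-1], sep)
}

func main() {
	fmt.Println(trimConversation("a###b###c###d###e"))
}
```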
6 changes: 5 additions & 1 deletion server/go.mod
@@ -15,7 +15,11 @@ require (
 	github.com/hashicorp/go-retryablehttp v0.7.4
 )

-require cloud.google.com/go/longrunning v0.6.1 // indirect
+require (
+	cloud.google.com/go/longrunning v0.6.1 // indirect
+	github.com/google/go-querystring v1.0.0 // indirect
+	github.com/vartanbeno/go-reddit/v2 v2.0.1 // indirect
+)

 require (
 	cloud.google.com/go v0.116.0 // indirect
