feat: created predict endpoint
telpirion authored Nov 25, 2024
2 parents e29574b + ecd3c3e commit e6c1555
Showing 11 changed files with 184 additions and 50 deletions.
4 changes: 3 additions & 1 deletion Dockerfile
@@ -16,7 +16,9 @@ COPY site/css ./site/css
COPY site/html ./site/html
COPY prompts ./server/templates
COPY server/favicon.ico ./server/favicon.ico
COPY server/* ./server
COPY server/generated ./server/generated
COPY server/ai ./server/ai
COPY server/*.go ./server

COPY server/go.mod server/go.sum ./server/
WORKDIR /server
25 changes: 25 additions & 0 deletions docs/server.md
@@ -46,3 +46,28 @@ $ docker tag myherodotus us-west1-docker.pkg.dev/${PROJECT_ID}/my-herodotus/base-image:${SEMVER}
$ docker push us-west1-docker.pkg.dev/${PROJECT_ID}/my-herodotus/base-image:${SEMVER}
```

## Get predictions directly from the API

The MyHerodotus app exposes an API endpoint, `/predict`, that allows callers to send
raw prediction requests to the AI system.

The following code sample demonstrates how to get a simple prediction from the `/predict`
endpoint using `curl`. This assumes that the MyHerodotus app is running locally and
listening on port `8080`.

```sh
curl --header "Content-Type: application/json" \
--request POST \
--data '{"message":"I want to go to Greece","model":"gemini"}' \
http://localhost:8080/predict
```

The following code sample demonstrates how to get a simple prediction from the `/predict`
endpoint of the deployed MyHerodotus app using `curl`.

```sh
curl --header "Content-Type: application/json" \
--request POST \
--data '{"message":"I want to go to Greece","model":"gemini"}' \
https://myherodotus-1025771077852.us-west1.run.app/predict
```
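
The same request can also be made programmatically. The following Go sketch posts the
JSON payload from the `curl` examples and prints the raw response body; because the
response format isn't documented here, no response schema is assumed.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	// Same JSON payload as the curl examples above.
	payload, err := json.Marshal(map[string]string{
		"message": "I want to go to Greece",
		"model":   "gemini",
	})
	if err != nil {
		log.Fatal(err)
	}

	resp, err := http.Post("http://localhost:8080/predict",
		"application/json", bytes.NewReader(payload))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Print the raw response body; the response schema is not assumed.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(body))
}
```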
101 changes: 80 additions & 21 deletions server/vertex.go → server/ai/vertex.go
@@ -1,4 +1,4 @@
package main
package ai

import (
"bytes"
@@ -17,6 +17,8 @@ import (
"cloud.google.com/go/vertexai/genai"
"google.golang.org/api/option"
"google.golang.org/protobuf/types/known/structpb"

"github.com/telpirion/MyHerodotus/generated"
)

const (
@@ -29,7 +31,30 @@ const (
MaxGemmaTokens int32 = 2048
)

var cachedContext string = ""
var (
cachedContext = ""
convoContext = ""
)

type Modality int

const (
Gemini Modality = iota
GeminiTuned
Gemma
AgentAssisted
EmbeddingsAssisted
)

var (
modalitiesMap = map[string]Modality{
"gemini": Gemini,
"gemini-tuned": GeminiTuned,
"gemma": Gemma,
"agent-assisted": AgentAssisted,
"embeddings-assisted": EmbeddingsAssisted,
}
)

type MinCacheNotReachedError struct {
ConversationCount int
@@ -45,8 +70,29 @@ type promptInput struct {
History string
}

// getTokenCount uses the Gemini tokenizer to count the tokens in some text.
func getTokenCount(text string) (int32, error) {
// Predict routes a user query to the model backend that matches the
// requested modality; unrecognized modalities fall back to Gemini.
func Predict(query, modality, projectID string) (response string, templateName string, err error) {

switch modalitiesMap[strings.ToLower(modality)] {
case Gemini:
response, err = textPredictGemini(query, projectID, Gemini)
case Gemma:
response, err = textPredictGemma(query, projectID)
case GeminiTuned:
response, err = textPredictGemini(query, projectID, GeminiTuned)
default:
response, err = textPredictGemini(query, projectID, Gemini)
}

if err != nil {
return "", "", err
}

cachedContext += fmt.Sprintf("### Human: %s\n### Assistant: %s\n", query, response)
return response, templateName, nil
}

// GetTokenCount uses the Gemini tokenizer to count the tokens in some text.
func GetTokenCount(text, projectID string) (int32, error) {
location := "us-west1"
ctx := context.Background()
client, err := genai.NewClient(ctx, projectID, location)
@@ -65,9 +111,9 @@ func getTokenCount(text string) (int32, error) {
return resp.TotalTokens, nil
}

// setConversationContext creates string out of past conversation between user and model.
// SetConversationContext creates a string out of the past conversation between the user and the model.
// This conversation history is used as grounding for the prompt template.
func setConversationContext(convoHistory []ConversationBit) error {
func SetConversationContext(convoHistory []generated.ConversationBit) error {
tmp, err := template.ParseFiles(HistoryTemplate)
if err != nil {
return err
@@ -93,15 +139,15 @@ func extractAnswer(response string) string {
}

// createPrompt generates a new prompt based upon the stored prompt template.
func createPrompt(message, templateName string) (string, error) {
func createPrompt(message, templateName, history string) (string, error) {
tmp, err := template.ParseFiles(templateName)
if err != nil {
return "", err
}

promptInputs := promptInput{
Query: message,
History: cachedContext,
History: history,
}

var buf bytes.Buffer
@@ -122,25 +168,33 @@ func textPredictGemma(message, projectID string) (string, error) {
apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
if err != nil {
LogError(fmt.Sprintf("unable to create prediction client: %v\n", err))
return "", err
}
defer client.Close()

parameters := map[string]interface{}{}

prompt, err := createPrompt(message, GemmaTemplate)
prompt, err := createPrompt(message, GemmaTemplate, cachedContext)
if err != nil {
LogError(fmt.Sprintf("unable to create Gemma prompt: %v\n", err))
return "", err
}

tokenCount, err := GetTokenCount(prompt, projectID)
if err != nil {
return "", fmt.Errorf("error counting input tokens: %w", err)
}
if tokenCount > MaxGemmaTokens {
// The prompt with the full history exceeds the Gemma token limit; retry
// with a trimmed history, falling back to the bare message on error.
prompt, err = createPrompt(message, GemmaTemplate, trimContext())
if err != nil {
prompt = message
}
}

promptValue, err := structpb.NewValue(map[string]interface{}{
"inputs": prompt,
"parameters": parameters,
})
if err != nil {
LogError(fmt.Sprintf("unable to create prompt value: %v\n", err))
return "", err
}

@@ -151,7 +205,6 @@ func textPredictGemma(message, projectID string) (string, error) {

resp, err := client.Predict(ctx, req)
if err != nil {
LogError(fmt.Sprintf("unable to make prediction: %v\n", err))
return "", err
}

@@ -163,19 +216,18 @@ }
}

// textPredictGemini generates text using a Gemini 1.5 Flash model
func textPredictGemini(message, projectID, modelVersion string) (string, error) {
func textPredictGemini(message, projectID string, modality Modality) (string, error) {
ctx := context.Background()
location := "us-west1"

client, err := genai.NewClient(ctx, projectID, location)
if err != nil {
LogError(fmt.Sprintf("unable to create genai client: %v\n", err))
return "", err
}
defer client.Close()

modelName := GeminiModel
if modelVersion == "gemini-tuned" {
if modality == GeminiTuned {
endpointID := os.Getenv("TUNED_MODEL_ENDPOINT_ID")
modelName = fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
}
@@ -185,21 +237,18 @@ func textPredictGemini(message, projectID, modelVersion string)
llm.CachedContentName = convoContext
}

prompt, err := createPrompt(message, GeminiTemplate)
prompt, err := createPrompt(message, GeminiTemplate, cachedContext)
if err != nil {
LogError(fmt.Sprintf("unable to create Gemini prompt: %v\n", err))
return "", err
}

resp, err := llm.GenerateContent(ctx, genai.Text(prompt))
if err != nil {
LogError(fmt.Sprintf("unable to generate content: %v\n", err))
return "", err
}

candidate, err := getCandidate(resp)
if err != nil {
LogError(err.Error())
return "I'm not sure how to answer that. Would you please repeat the question?", nil
}
return extractAnswer(candidate), nil
@@ -224,7 +273,7 @@ func getCandidate(resp *genai.GenerateContentResponse) (string, error) {

// StoreConversationContext uploads past user conversations with the model into a Gen AI context.
// This context is used when the model is answering questions from the user.
func storeConversationContext(conversationHistory []ConversationBit, projectID string) (string, error) {
func StoreConversationContext(conversationHistory []generated.ConversationBit, projectID string) (string, error) {
if len(conversationHistory) < MinimumConversationNum {
return "", &MinCacheNotReachedError{ConversationCount: len(conversationHistory)}
}
@@ -266,3 +315,13 @@ func storeConversationContext(conversationHistory []ConversationBit, projectID s

return resourceName, nil
}

// trimContext keeps only the most recent exchanges from the cached
// conversation so that the prompt stays within the model's token limit.
func trimContext() (last string) {
sep := "###"
convos := strings.Split(cachedContext, sep)
length := len(convos)
if length > 3 {
last = strings.Join(convos[length-3:length-1], sep)
}
return last
}
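
For context, here is a minimal sketch of how the server package might wire the `/predict` endpoint to the new `ai.Predict` function. The handler name, request struct, and `PROJECT_ID` environment variable are assumptions for illustration; the actual handler isn't rendered in this diff.

```go
// Hypothetical wiring sketch; not part of this commit's rendered diff.
package main

import (
	"encoding/json"
	"log"
	"net/http"
	"os"

	"github.com/telpirion/MyHerodotus/ai"
)

type predictRequest struct {
	Message string `json:"message"`
	Model   string `json:"model"`
}

func predictHandler(w http.ResponseWriter, r *http.Request) {
	var req predictRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "invalid request body", http.StatusBadRequest)
		return
	}

	// ai.Predict routes the query to the backend named in the request.
	response, _, err := ai.Predict(req.Message, req.Model, os.Getenv("PROJECT_ID"))
	if err != nil {
		http.Error(w, "prediction failed", http.StatusInternalServerError)
		return
	}
	w.Write([]byte(response))
}

func main() {
	http.HandleFunc("/predict", predictHandler)
	log.Fatal(http.ListenAndServe(":8080", nil))
}
```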
4 changes: 2 additions & 2 deletions server/vertex_test.go → server/ai/vertex_test.go
@@ -1,4 +1,4 @@
package main
package ai

import (
"strings"
@@ -28,7 +28,7 @@ func TestSetConversationContext(t *testing.T) {
BotResponse: "test bot response 2",
},
}
err := setConversationContext(convoHistory)
err := SetConversationContext(convoHistory)
if err != nil {
t.Fatal(err)
}
12 changes: 7 additions & 5 deletions server/db.go
@@ -34,6 +34,8 @@ import (

"cloud.google.com/go/firestore"
"google.golang.org/api/iterator"

"github.com/telpirion/MyHerodotus/generated"
)

const (
@@ -55,10 +57,10 @@ var CollectionName string = "HerodotusDev"
*/
type ConversationHistory struct {
UserEmail string
Conversations []ConversationBit
Conversations []generated.ConversationBit
}

func saveConversation(convo ConversationBit, userEmail, projectID string) (string, error) {
func saveConversation(convo generated.ConversationBit, userEmail, projectID string) (string, error) {
ctx := context.Background()

// Get CollectionName for running in staging or prod
@@ -104,9 +106,9 @@ func updateConversation(documentId, userEmail, rating, projectID string) error {
return nil
}

func getConversation(userEmail, projectID string) ([]ConversationBit, error) {
func getConversation(userEmail, projectID string) ([]generated.ConversationBit, error) {
ctx := context.Background()
conversations := []ConversationBit{}
conversations := []generated.ConversationBit{}
client, err := firestore.NewClientWithDatabase(ctx, projectID, DBName)
if err != nil {
LogError(fmt.Sprintf("firestore.Client: %v\n", err))
@@ -126,7 +128,7 @@ func getConversation(userEmail, projectID string) ([]ConversationBit, error) {
LogError(fmt.Sprintf("Firestore Iterator: %v\n", err))
return conversations, err
}
var convo ConversationBit
var convo generated.ConversationBit
err = doc.DataTo(&convo)
if err != nil {
LogError(fmt.Sprintf("Firestore document unmarshaling: %v\n", err))

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion server/go.mod
@@ -1,4 +1,4 @@
module my-herodotus
module github.com/telpirion/MyHerodotus

go 1.23.0

Binary file added server/main
Binary file not shown.
