Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add preliminary parsing of parameters #13

Merged
merged 1 commit into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ func TestParseToolFile(t *testing.T) {
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
country:
type: string
description: some description
- name: country
type: string
description: some description
`,
wantSources: sources.Configs{
"my-pg-instance": sources.CloudSQLPgConfig{
Expand All @@ -211,8 +211,9 @@ func TestParseToolFile(t *testing.T) {
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: map[string]tools.Parameter{
"country": {
Parameters: []tools.Parameter{
{
Name: "country",
Type: "string",
Description: "some description",
},
Expand Down
23 changes: 20 additions & 3 deletions internal/server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"net/http"

"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/render"
)

Expand All @@ -28,8 +29,10 @@ func apiRouter(s *Server) chi.Router {

r.Get("/toolset/{toolsetName}", toolsetHandler(s))

// TODO: make this POST
r.Get("/tool/{toolName}", toolHandler(s))
r.Route("/tool/{toolName}", func(r chi.Router) {
r.Use(middleware.AllowContentType("application/json"))
r.Post("/", toolHandler(s))
})

return r
}
Expand All @@ -51,7 +54,21 @@ func toolHandler(s *Server) http.HandlerFunc {
return
}

res, err := tool.Invoke()
var data map[string]interface{}
if err := render.DecodeJSON(r.Body, &data); err != nil {
render.Status(r, http.StatusBadRequest)
return
}

params, err := tool.ParseParams(data)
if err != nil {
render.Status(r, http.StatusBadRequest)
// TODO: More robust error formatting (probably JSON)
render.PlainText(w, r, err.Error())
return
}

res, err := tool.Invoke(params)
if err != nil {
render.Status(r, http.StatusInternalServerError)
return
Expand Down
34 changes: 20 additions & 14 deletions internal/tools/cloud_sql_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ const CloudSQLPgSQLGenericKind string = "cloud-sql-postgres-generic"
var _ Config = CloudSQLPgGenericConfig{}

type CloudSQLPgGenericConfig struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
Parameters map[string]Parameter `yaml:"parameters"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
Parameters []Parameter `yaml:"parameters"`
}

func (r CloudSQLPgGenericConfig) toolKind() string {
Expand All @@ -53,9 +53,10 @@ func (r CloudSQLPgGenericConfig) Initialize(srcs map[string]sources.Source) (Too

// finish tool setup
t := CloudSQLPgGenericTool{
Name: r.Name,
Kind: CloudSQLPgSQLGenericKind,
Source: s,
Name: r.Name,
Kind: CloudSQLPgSQLGenericKind,
Source: s,
Parameters: r.Parameters,
}
return t, nil
}
Expand All @@ -64,11 +65,16 @@ func (r CloudSQLPgGenericConfig) Initialize(srcs map[string]sources.Source) (Too
var _ Tool = CloudSQLPgGenericTool{}

type CloudSQLPgGenericTool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source sources.CloudSQLPgSource
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source sources.CloudSQLPgSource
Parameters []Parameter `yaml:"parameters"`
}

func (t CloudSQLPgGenericTool) Invoke() (string, error) {
return fmt.Sprintf("Stub tool call for %q!", t.Name), nil
func (t CloudSQLPgGenericTool) Invoke(params []any) (string, error) {
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q", t.Name, params), nil
}

func (t CloudSQLPgGenericTool) ParseParams(data map[string]any) ([]any, error) {
return parseParams(t.Parameters, data)
}
11 changes: 6 additions & 5 deletions internal/tools/cloud_sql_pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ func TestParseFromYaml(t *testing.T) {
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
country:
type: string
description: some description
- name: country
type: string
description: some description
`,
want: tools.Configs{
"example_tool": tools.CloudSQLPgGenericConfig{
Expand All @@ -51,8 +51,9 @@ func TestParseFromYaml(t *testing.T) {
Source: "my-pg-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: map[string]tools.Parameter{
"country": {
Parameters: []tools.Parameter{
{
Name: "country",
Type: "string",
Description: "some description",
},
Expand Down
43 changes: 40 additions & 3 deletions internal/tools/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,50 @@ func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
return nil
}

type Tool interface {
Invoke([]any) (string, error)
ParseParams(data map[string]any) ([]any, error)
}

type Parameter struct {
Name string `yaml:"name"`
Type string `yaml:"type"`
Description string `yaml:"description"`
Required bool `yaml:"required"`
Copy link
Contributor

@Yuan325 Yuan325 Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side note but I'm thinking having this might help improve latency too(?). Ideally, the LLM will prompt user to provide more information before moving on to the next steps since the field is required.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm removing this because I want to default for all parameters to be required -- I'm unsure of how to handle optional parameters in the SQL query

}

type Tool interface {
Invoke() (string, error)
// ParseTypeError is a customer error for incorrectly typed Parameters.
type ParseTypeError struct {
Name string
Type string
Value any
}

func (e ParseTypeError) Error() string {
return fmt.Sprintf("Error parsing parameter %q: %q not type %q", e.Name, e.Value, e.Type)
}

// ParseParams is a helper function for parsing Parameters from an arbitratyJSON object.
func parseParams(ps []Parameter, data map[string]any) ([]any, error) {
params := []any{}
for _, p := range ps {
v, ok := data[p.Name]
if !ok {
return nil, fmt.Errorf("Parameter %q is required!", p.Name)
}
switch p.Type {
case "string":
v, ok = v.(string)
case "integer":
v, ok = v.(int)
case "float":
v, ok = v.(float32)
case "boolean":
v, ok = v.(bool)
}
if !ok {
return nil, &ParseTypeError{p.Name, p.Type, v}
}
params = append(params, v)
}
return params, nil
}
Loading