diff --git a/internal/app/store/sql_store.go b/internal/app/store/sql_store.go index 941bbb3..9d61b7e 100644 --- a/internal/app/store/sql_store.go +++ b/internal/app/store/sql_store.go @@ -51,8 +51,9 @@ func validateTableName(name string) error { return nil } -func genSortString(sortFields []string) (string, error) { +func genSortString(sortFields []string, mapper fieldMapper) (string, error) { var buf bytes.Buffer + var err error for i, field := range sortFields { if i > 0 { @@ -62,20 +63,30 @@ func genSortString(sortFields []string) (string, error) { lower := strings.ToLower(field) if strings.HasSuffix(lower, ":"+SORT_DESCENDING) { field = strings.TrimSpace(field[:len(field)-len(":"+SORT_DESCENDING)]) - mapped, err := sqliteFieldMapper(field) - if err != nil { - return "", err + + mapped := field + if mapper != nil { + mapped, err = mapper(field) + if err != nil { + return "", err + } } buf.WriteString(mapped) buf.WriteString(" DESC") + } else { if strings.HasSuffix(lower, ":"+SORT_ASCENDING) { // :ASC is optional field = strings.TrimSpace(field[:len(field)-len(":"+SORT_ASCENDING)]) } - mapped, err := sqliteFieldMapper(field) - if err != nil { - return "", err + + mapped := field + if mapper != nil { + mapped, err = mapper(field) + if err != nil { + return "", err + } } + buf.WriteString(mapped) buf.WriteString(" ASC") } @@ -157,11 +168,49 @@ func (s *SqlStore) initStore() error { return fmt.Errorf("error creating table %s: %w", table, err) } s.Info().Msgf("Created table %s", table) + + unquotedTable := strings.Trim(table, "'") + if storeType.Indexes != nil { + for _, index := range storeType.Indexes { + + indexStmt, err := createIndexStmt(unquotedTable, index) + if err != nil { + return err + } + + _, err = s.db.Exec(indexStmt) + s.Trace().Msgf("indexStmt: %s", indexStmt) + if err != nil { + return fmt.Errorf("error creating index on %s: %w", unquotedTable, err) + } + } + } } return nil } +func createIndexStmt(unquotedTableName string, index utils.Index) (string, error) { + mappedColumns, err := genSortString(index.Fields, sqliteFieldMapper) + if err != nil { + return "", fmt.Errorf("error generating index columns for table %s: %w", unquotedTableName, err) + } + unmappedColumns, err := genSortString(index.Fields, nil) + if err != nil { + return "", fmt.Errorf("error generating index columns for table %s: %w", unquotedTableName, err) + } + indexName := fmt.Sprintf("index_%s_%s", unquotedTableName, strings.ReplaceAll(unmappedColumns, ", ", "_")) + indexName = strings.ReplaceAll(indexName, " ", "_") + + unique := " " + if index.Unique { + unique = " UNIQUE " + } + + indexStmt := fmt.Sprintf("CREATE%sINDEX IF NOT EXISTS '%s' ON '%s' (%s)", unique, indexName, unquotedTableName, mappedColumns) + return indexStmt, nil +} + // Insert a new entry in the store func (s *SqlStore) Insert(table string, entry *Entry) (EntryId, error) { if err := s.initialize(); err != nil { @@ -185,12 +234,12 @@ func (s *SqlStore) Insert(table string, entry *Entry) (EntryId, error) { createStmt := "INSERT INTO " + table + " (_version, _created_by, _updated_by, _created_at, _updated_at, _json) VALUES (?, ?, ?, ?, ?, ?)" result, err := s.db.Exec(createStmt, entry.Version, entry.CreatedBy, entry.UpdatedBy, entry.CreatedAt.UnixMilli(), entry.UpdatedAt.UnixMilli(), dataJson) if err != nil { - return -1, nil + return -1, err } insertId, err := result.LastInsertId() if err != nil { - return -1, nil + return -1, err } return EntryId(insertId), nil } @@ -258,7 +307,7 @@ func (s *SqlStore) Select(table string, filter map[string]any, sort []string, of var sortStr string if len(sort) > 0 { - sortStr, err = genSortString(sort) + sortStr, err = genSortString(sort, sqliteFieldMapper) if err != nil { return nil, err } diff --git a/internal/app/store/sql_store_test.go b/internal/app/store/sql_store_test.go index f9a5814..ca10ddf 100644 --- a/internal/app/store/sql_store_test.go +++ b/internal/app/store/sql_store_test.go @@ -5,6 +5,8 @@ package store import ( "testing" + + "github.com/claceio/clace/internal/utils" ) func TestGenTableName(t *testing.T) { @@ -27,13 +29,55 @@ func TestGenTableName(t *testing.T) { func TestGenSortString(t *testing.T) { sort := []string{"field1:asc", "field2:DEsc", "_id"} - expected := "_json ->> 'field1' ASC, _json ->> 'field2' DESC, _id ASC" - result, err := genSortString(sort) + + result, err := genSortString(sort, sqliteFieldMapper) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result != expected { + t.Errorf("Expected %s, but got %s", expected, result) + } + + result, err = genSortString(sort, nil) // no field name mapping + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if result != "field1 ASC, field2 DESC, _id ASC" { + t.Errorf("Expected %s, but got %s", expected, result) + } +} + +// test for createIndexStmt +func TestCreateIndexStmt(t *testing.T) { + table := "prefix_table" + index := utils.Index{ + Fields: []string{"field:asc", "_id:desc"}, + Unique: false, + } + + result, err := createIndexStmt(table, index) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected := "CREATE INDEX IF NOT EXISTS 'index_prefix_table_field_ASC__id_DESC' ON 'prefix_table' (_json ->> 'field' ASC, _id DESC)" + if result != expected { + t.Errorf("Expected %s, but got %s", expected, result) + } + + index = utils.Index{ + Fields: []string{"map.key", "_id:desc"}, + Unique: true, + } + result, err = createIndexStmt(table, index) if err != nil { t.Errorf("Unexpected error: %v", err) } + expected = "CREATE UNIQUE INDEX IF NOT EXISTS 'index_prefix_table_map.key_ASC__id_DESC' ON 'prefix_table' (_json ->> 'map.key' ASC, _id DESC)" if result != expected { t.Errorf("Expected %s, but got %s", expected, result) } diff --git a/internal/app/tests/store_test.go b/internal/app/tests/store_test.go index cec29b4..3a3b15d 100644 --- a/internal/app/tests/store_test.go +++ b/internal/app/tests/store_test.go @@ -47,6 +47,9 @@ def handler(req): ret3 = store.insert(table.test1, myt) if not ret3: return {"error": ret3.error} + ret4 = store.insert(table.test1, myt) + if ret4: # Expect to fail + return {"error": "Expected duplicate insert to fail"} id = ret.value ret = store.select_by_id(table.test1, id) @@ -116,7 +119,10 @@ type("test1", fields=[ field("abool", BOOLEAN), field("alist", LIST), field("adict", DICT), -])`, +], +indexes=[ + index(["aint:asc", "astring:desc"], unique=True) + ])`, "index.go.html": ``, }