Skip to content

Commit

Permalink
Added store plugin APIs and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
akclace committed Jan 19, 2024
1 parent ae55825 commit 8f69178
Show file tree
Hide file tree
Showing 18 changed files with 319 additions and 204 deletions.
20 changes: 10 additions & 10 deletions internal/app/app_plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,29 @@ func NewAppPlugins(app *App, pluginConfig map[string]utils.PluginSettings, appAc
}
}

func (p *AppPlugins) GetPlugin(pluginInfo *PluginInfo, accountName string) (any, error) {
func (p *AppPlugins) GetPlugin(pluginInfo *utils.PluginInfo, accountName string) (any, error) {
p.Lock()
defer p.Unlock()

plugin, ok := p.plugins[pluginInfo.moduleName]
plugin, ok := p.plugins[pluginInfo.ModuleName]
if ok {
// Already initialized, use that
return plugin, nil
}

// If account name is specified, use that to lookup the account map
accountLookupName := pluginInfo.pluginPath
accountLookupName := pluginInfo.PluginPath
if accountName != "" {
accountLookupName = fmt.Sprintf("%s%s%s", pluginInfo.pluginPath, util.ACCOUNT_SEPERATOR, accountName)
accountLookupName = fmt.Sprintf("%s%s%s", pluginInfo.PluginPath, util.ACCOUNT_SEPERATOR, accountName)
}

pluginAccount := pluginInfo.pluginPath
pluginAccount := pluginInfo.PluginPath
_, ok = p.accountMap[accountLookupName]
if ok {
pluginAccount = p.accountMap[accountLookupName]
// If it is just account name, make it full plugin path
if !strings.Contains(pluginAccount, util.ACCOUNT_SEPERATOR) {
pluginAccount = fmt.Sprintf("%s%s%s", pluginInfo.pluginPath, util.ACCOUNT_SEPERATOR, pluginAccount)
pluginAccount = fmt.Sprintf("%s%s%s", pluginInfo.PluginPath, util.ACCOUNT_SEPERATOR, pluginAccount)
}
}

Expand All @@ -66,17 +66,17 @@ func (p *AppPlugins) GetPlugin(pluginInfo *PluginInfo, accountName string) (any,
appConfig = p.pluginConfig[pluginAccount]
}

pluginContext := &PluginContext{
pluginContext := &utils.PluginContext{
Logger: p.app.Logger,
AppId: p.app.AppEntry.Id,
StoreInfo: p.app.storeInfo,
Config: appConfig,
}
plugin, err := pluginInfo.builder(pluginContext)
plugin, err := pluginInfo.Builder(pluginContext)
if err != nil {
return nil, fmt.Errorf("error creating plugin %s: %w", pluginInfo.funcName, err)
return nil, fmt.Errorf("error creating plugin %s: %w", pluginInfo.FuncName, err)
}

p.plugins[pluginInfo.pluginPath] = plugin
p.plugins[pluginInfo.PluginPath] = plugin
return plugin, nil
}
48 changes: 24 additions & 24 deletions internal/app/app_plugins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,70 +33,70 @@ func TestGetPlugin(t *testing.T) {
appPlugins := NewAppPlugins(app, pluginConfig, appAccounts)

// Define the pluginInfo and accountName for testing
pluginInfo := &PluginInfo{
moduleName: "plugin1",
pluginPath: "plugin1.in",
funcName: "Plugin1Builder",
pluginInfo := &utils.PluginInfo{
ModuleName: "plugin1",
PluginPath: "plugin1.in",
FuncName: "Plugin1Builder",
}

// Test with no account, no account link
pluginInfo.builder = func(pluginContext *PluginContext) (any, error) {
pluginInfo.Builder = func(pluginContext *utils.PluginContext) (any, error) {
testutil.AssertEqualsString(t, "match key", "v1", pluginContext.Config["key"].(string))
return nil, nil
}
plugin, err := appPlugins.GetPlugin(pluginInfo, "")
testutil.AssertNoError(t, err)
if plugin != appPlugins.plugins[pluginInfo.moduleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.moduleName], plugin)
if plugin != appPlugins.plugins[pluginInfo.ModuleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.ModuleName], plugin)
}

// Test with no account, with account link
pluginInfo.moduleName = "plugin2"
pluginInfo.pluginPath = "plugin2.in"
pluginInfo.builder = func(pluginContext *PluginContext) (any, error) {
pluginInfo.ModuleName = "plugin2"
pluginInfo.PluginPath = "plugin2.in"
pluginInfo.Builder = func(pluginContext *utils.PluginContext) (any, error) {
testutil.AssertEqualsString(t, "match key", "v5", pluginContext.Config["key"].(string))
return nil, nil
}
plugin, err = appPlugins.GetPlugin(pluginInfo, "")
testutil.AssertNoError(t, err)
if plugin != appPlugins.plugins[pluginInfo.moduleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.moduleName], plugin)
if plugin != appPlugins.plugins[pluginInfo.ModuleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.ModuleName], plugin)
}

// Test with account, with no account link
pluginInfo.pluginPath = "plugin2.in#account1"
pluginInfo.builder = func(pluginContext *PluginContext) (any, error) {
pluginInfo.PluginPath = "plugin2.in#account1"
pluginInfo.Builder = func(pluginContext *utils.PluginContext) (any, error) {
testutil.AssertEqualsString(t, "match key", "v4", pluginContext.Config["key"].(string))
return nil, nil
}
plugin, err = appPlugins.GetPlugin(pluginInfo, "")
testutil.AssertNoError(t, err)
if plugin != appPlugins.plugins[pluginInfo.moduleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.moduleName], plugin)
if plugin != appPlugins.plugins[pluginInfo.ModuleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.ModuleName], plugin)
}

// Test with account, with account link
pluginInfo.pluginPath = "plugin2.in#account2"
pluginInfo.builder = func(pluginContext *PluginContext) (any, error) {
pluginInfo.PluginPath = "plugin2.in#account2"
pluginInfo.Builder = func(pluginContext *utils.PluginContext) (any, error) {
testutil.AssertEqualsString(t, "match key", "v6", pluginContext.Config["key"].(string))
return nil, nil
}
plugin, err = appPlugins.GetPlugin(pluginInfo, "")
testutil.AssertNoError(t, err)
if plugin != appPlugins.plugins[pluginInfo.moduleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.moduleName], plugin)
if plugin != appPlugins.plugins[pluginInfo.ModuleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.ModuleName], plugin)
}

// Test with invalid account
pluginInfo.pluginPath = "plugin2.in#invalid"
pluginInfo.builder = func(pluginContext *PluginContext) (any, error) {
pluginInfo.PluginPath = "plugin2.in#invalid"
pluginInfo.Builder = func(pluginContext *utils.PluginContext) (any, error) {
// Config should have no entries
testutil.AssertEqualsInt(t, "match key", 0, len(pluginContext.Config))
return nil, nil
}
plugin, err = appPlugins.GetPlugin(pluginInfo, "")
testutil.AssertNoError(t, err)
if plugin != appPlugins.plugins[pluginInfo.moduleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.moduleName], plugin)
if plugin != appPlugins.plugins[pluginInfo.ModuleName] {
t.Errorf("Expected %v, got %v", appPlugins.plugins[pluginInfo.ModuleName], plugin)
}
}
73 changes: 22 additions & 51 deletions internal/app/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,71 +19,42 @@ import (
"go.starlark.net/starlarkstruct"
)

type PluginContext struct {
Logger *utils.Logger
AppId utils.AppId
StoreInfo *utils.StoreInfo
Config utils.PluginSettings
}

type NewPluginFunc func(pluginContext *PluginContext) (any, error)

var (
loaderInitMutex sync.Mutex
builtInPlugins map[string]PluginMap
builtInPlugins map[string]utils.PluginMap
)

func init() {
builtInPlugins = make(map[string]PluginMap)
builtInPlugins = make(map[string]utils.PluginMap)
}

// RegisterPlugin registers a plugin with Clace
func RegisterPlugin(name string, builder NewPluginFunc, funcs []PluginFunc) {
func RegisterPlugin(name string, builder utils.NewPluginFunc, funcs []utils.PluginFunc) {
loaderInitMutex.Lock()
defer loaderInitMutex.Unlock()

pluginPath := fmt.Sprintf("%s.%s", name, util.BUILTIN_PLUGIN_SUFFIX)
pluginMap := make(PluginMap)
pluginMap := make(utils.PluginMap)
for _, f := range funcs {
info := PluginInfo{
moduleName: name,
pluginPath: pluginPath,
funcName: f.name,
isRead: f.isRead,
handlerName: f.functionName,
builder: builder,
info := utils.PluginInfo{
ModuleName: name,
PluginPath: pluginPath,
FuncName: f.Name,
IsRead: f.IsRead,
HandlerName: f.FunctionName,
Builder: builder,
}

pluginMap[f.name] = &info
pluginMap[f.Name] = &info
}

builtInPlugins[pluginPath] = pluginMap
}

// PluginMap is the plugin function mapping to PluginFuncs
type PluginMap map[string]*PluginInfo

// PluginFunc is the Clace plugin function mapping to starlark function
type PluginFunc struct {
name string
isRead bool
functionName string
}

// PluginFuncInfo is the Clace plugin function info for the starlark function
type PluginInfo struct {
moduleName string // exec
pluginPath string // exec.in
funcName string // run
isRead bool
handlerName string
builder NewPluginFunc
}

func CreatePluginApi(
f func(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error),
isRead bool,
) PluginFunc {
) utils.PluginFunc {

funcVal := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
if funcVal == nil {
Expand All @@ -101,7 +72,7 @@ func CreatePluginApi(
func CreatePluginApiName(
f func(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error),
isRead bool,
name string) PluginFunc {
name string) utils.PluginFunc {
funcVal := runtime.FuncForPC(reflect.ValueOf(f).Pointer())
if funcVal == nil {
panic(fmt.Errorf("function %s not found during plugin register", name))
Expand All @@ -119,10 +90,10 @@ func CreatePluginApiName(
panic(fmt.Errorf("function %s is not an exported method during plugin register", funcName))
}

return PluginFunc{
name: name,
isRead: isRead,
functionName: funcName,
return utils.PluginFunc{
Name: name,
IsRead: isRead,
FunctionName: funcName,
}
}

Expand Down Expand Up @@ -169,7 +140,7 @@ func parseModulePath(moduleFullPath string) (string, string, string) {
}

// pluginLookup looks up the plugin. Audit checks need to be done by the caller
func (a *App) pluginLookup(_ *starlark.Thread, module string) (PluginMap, error) {
func (a *App) pluginLookup(_ *starlark.Thread, module string) (utils.PluginMap, error) {
pluginDict, ok := builtInPlugins[module]
if !ok {
return nil, fmt.Errorf("module %s not found", module) // TODO extend loading
Expand All @@ -178,7 +149,7 @@ func (a *App) pluginLookup(_ *starlark.Thread, module string) (PluginMap, error)
return pluginDict, nil
}

func (a *App) pluginHook(modulePath, accountName, functionName string, pluginInfo *PluginInfo) *starlark.Builtin {
func (a *App) pluginHook(modulePath, accountName, functionName string, pluginInfo *utils.PluginInfo) *starlark.Builtin {
hook := func(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
a.Trace().Msgf("Plugin called: %s.%s", modulePath, functionName)

Expand Down Expand Up @@ -220,7 +191,7 @@ func (a *App) pluginHook(modulePath, accountName, functionName string, pluginInf
isRead = *p.IsRead
} else {
// Use the plugin defined isRead value
isRead = pluginInfo.isRead
isRead = pluginInfo.IsRead
}

if !isRead {
Expand Down Expand Up @@ -255,7 +226,7 @@ func (a *App) pluginHook(modulePath, accountName, functionName string, pluginInf
}

// Get the plugin function using reflection
pluginValue := reflect.ValueOf(plugin).MethodByName(pluginInfo.handlerName)
pluginValue := reflect.ValueOf(plugin).MethodByName(pluginInfo.HandlerName)
if pluginValue.IsNil() {
return nil, fmt.Errorf("plugin func %s.%s cannot be resolved", modulePath, functionName)
}
Expand Down
Loading

0 comments on commit 8f69178

Please sign in to comment.