Skip to content

Commit

Permalink
Add support to clone functions with overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
kelveny committed Jun 5, 2021
1 parent 08a35e3 commit fafb07f
Show file tree
Hide file tree
Showing 11 changed files with 400 additions and 30 deletions.
16 changes: 11 additions & 5 deletions cmd/clzgenerator.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type classMethodGenerator struct {
mockPkgName string // package name that mocking class resides
mockName string // the mocking composite class name

mothodsToClone []string // method function names that need to be cloned in mocking class
methodsToClone []string // method function names that need to be cloned in mocking class
methodsToMock []string // method function names that need to be mocked
}

Expand Down Expand Up @@ -64,8 +64,8 @@ func (g *classMethodGenerator) match(fnSpec *ast.FuncDecl) (bool, matchType) {
}

func (g *classMethodGenerator) matchMethod(fnName string) matchType {
if len(g.mothodsToClone) > 0 {
for _, name := range g.mothodsToClone {
if len(g.methodsToClone) > 0 {
for _, name := range g.methodsToClone {
// in format of methodName,pkg1=mockPkg1:pkg2=mockPkg2
name = strings.Split(name, ",")[0]
if name == fnName {
Expand All @@ -86,7 +86,7 @@ func (g *classMethodGenerator) matchMethod(fnName string) matchType {
}

func (g *classMethodGenerator) getMethodOverrides(fnName string) map[string]string {
for _, name := range g.mothodsToClone {
for _, name := range g.methodsToClone {
// in format of methodName,pkg1=mockPkg1:pkg2=mockPkg2
tokens := strings.Split(name, ",")
if tokens[0] == fnName {
Expand Down Expand Up @@ -181,7 +181,13 @@ func (g *classMethodGenerator) generateInternal(
if matchType == MATCH_CLONE {
// clone receiver-modified method
overrides := g.getMethodOverrides(fnSpec.Name.Name)
gogen.WriteFuncWithLocalOverrides(writer, fset, fnSpec, overrides)
gogen.WriteFuncWithLocalOverrides(
writer,
fset,
fnSpec,
fnSpec.Name.Name,
overrides,
)
} else if matchType == MATCH_MOCK {
// generate mocked method
g.composeMock(writer, fset, fnSpec)
Expand Down
142 changes: 142 additions & 0 deletions cmd/fnclonegenerator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package cmd

import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io"
"strings"

"github.com/kelveny/mockcompose/pkg/gogen"
"github.com/kelveny/mockcompose/pkg/logger"
)

type functionCloneGenerator struct {
mockPkgName string // package name that cloned functions reside
mockName string // name used to form generated file name
methodsToClone []string // function names that need to be cloned
}

// use compiler to enforce interface compliance
var _ parsedFileGenerator = (*functionCloneGenerator)(nil)

// match checks if a FuncDecl matches condition, if matched,
// modify function name
func (g *functionCloneGenerator) match(fnSpec *ast.FuncDecl) bool {
if fnSpec.Recv == nil {
return g.matchMethod(fnSpec.Name.Name)
}

return false
}

func (g *functionCloneGenerator) matchMethod(fnName string) bool {
if len(g.methodsToClone) > 0 {
for _, name := range g.methodsToClone {
// in format of methodName,pkg1=mockPkg1:pkg2=mockPkg2
name = strings.Split(name, ",")[0]
if name == fnName {
return true
}
}
}

return false
}

func (g *functionCloneGenerator) getMethodOverrides(fnName string) map[string]string {
for _, name := range g.methodsToClone {
// in format of methodName,pkg1=mockPkg1:pkg2=mockPkg2
tokens := strings.Split(name, ",")
if tokens[0] == fnName {
if len(tokens) > 1 {
pairs := strings.Split(tokens[1], ":")

overrides := make(map[string]string)
for _, pair := range pairs {
kv := strings.Split(pair, "=")
if len(kv) == 2 {
overrides[kv[0]] = kv[1]
} else {
logger.Log(logger.ERROR, "invalid configuration: -real %s\n", name)
}
}
return overrides
}
}
}
return nil
}

func (g *functionCloneGenerator) generate(
writer io.Writer,
file *ast.File,
) error {

var buf bytes.Buffer

fset := token.NewFileSet()
if g.generateInternal(&buf, fset, file) {
// reload generated content to process generated code the second time
f, err := parser.ParseFile(fset, "", buf.Bytes(), parser.ParseComments)
if err != nil {
logger.Log(logger.ERROR, "Internal error: %s\n\n%s\n", err, buf.String())
return err
}

// remove unused imports
var cleanedImports []gogen.ImportSpec = []gogen.ImportSpec{}
cleanedImports = gogen.CleanImports(f, cleanedImports)

// compose final output
fmt.Fprintf(writer, header, g.mockPkgName)

gogen.WriteImportDecls(writer, cleanedImports)
gogen.WriteFuncDecls(writer, fset, f)
}

return nil
}

func (g *functionCloneGenerator) generateInternal(
writer io.Writer,
fset *token.FileSet,
file *ast.File,
) bool {
found := false

writer.Write([]byte(fmt.Sprintf("package %s\n\n", g.mockPkgName)))

if len(file.Decls) > 0 {
for _, d := range file.Decls {
if fnSpec, ok := d.(*ast.FuncDecl); ok {
matched := g.match(fnSpec)
if matched {
found = true
}

if matched {
overrides := g.getMethodOverrides(fnSpec.Name.Name)
gogen.WriteFuncWithLocalOverrides(
writer,
fset,
fnSpec,
fnSpec.Name.Name+"_clone",
overrides,
)
}
} else {
// for any non-function declaration, export only imports
if dd, ok := d.(*ast.GenDecl); ok && dd.Tok == token.IMPORT {
format.Node(writer, fset, d)
writer.Write([]byte("\n\n"))
}
}
}
}

return found
}
44 changes: 28 additions & 16 deletions cmd/mockcompose.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func Execute() {
clzName: *options.clzName,
mockPkgName: *options.mockPkg,
mockName: *options.mockName,
mothodsToClone: options.methodsToClone,
methodsToClone: options.methodsToClone,
methodsToMock: options.methodsToMock,
}
} else if *intfName != "" {
Expand All @@ -100,27 +100,39 @@ func Execute() {
scanPackageToGenerate(g.(loadedPackageGenerator), options)
return
}

} else {
if len(methodsToMock) == 0 {
if len(methodsToClone) > 0 {
logger.Log(logger.ERROR, "Please use -real option together with -c option\n")
os.Exit(1)
}

logger.Log(logger.ERROR, "Please specify at least one mock function name with -mock option\n")
if len(methodsToMock) == 0 && len(methodsToClone) == 0 {
logger.Log(logger.ERROR, "no function to mock or clone\n")
os.Exit(1)
}

g = &functionMockGenerator{
mockPkgName: *options.mockPkg,
mockName: *options.mockName,
methodsToMock: *&options.methodsToMock,
if len(methodsToMock) > 0 && len(methodsToClone) > 0 {
logger.Log(logger.ERROR, "option -real and option -mock are exclusive in function clone generation\n")
os.Exit(1)
}

if *options.srcPkg != "" {
scanPackageToGenerate(g.(loadedPackageGenerator), options)
return
if len(methodsToClone) > 0 {
if *options.srcPkg != "" {
logger.Log(logger.PROMPT,
"No source package support in function clone generation, ignore source package %s\n",
*options.srcPkg)
}
g = &functionCloneGenerator{
mockPkgName: *options.mockPkg,
mockName: *options.mockName,
methodsToClone: *&options.methodsToClone,
}
} else {
g = &functionMockGenerator{
mockPkgName: *options.mockPkg,
mockName: *options.mockName,
methodsToMock: *&options.methodsToMock,
}

if *options.srcPkg != "" {
scanPackageToGenerate(g.(loadedPackageGenerator), options)
return
}
}
}

Expand Down
38 changes: 30 additions & 8 deletions pkg/gogen/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"go/token"
"go/types"
"io"
"sort"
"strings"

"github.com/kelveny/mockcompose/pkg/gosyntax"
Expand Down Expand Up @@ -155,21 +156,33 @@ func WriteFuncWithLocalOverrides(
writer io.Writer,
fset *token.FileSet,
fnSpec *ast.FuncDecl,
fnName string,
overrides map[string]string,
) {
if len(overrides) == 0 {
format.Node(writer, fset, fnSpec)
writer.Write([]byte("\n\n"))
} else {
fmt.Fprintf(
writer,
"func (%s) %s",
gosyntax.ParamListDeclString(fset, fnSpec.Recv),
fnSpec.Name.Name,
)
if fnSpec.Recv != nil {
fmt.Fprintf(
writer,
"func (%s) %s",
gosyntax.ParamListDeclString(fset, fnSpec.Recv),
fnName,
)
} else {
fmt.Fprintf(
writer,
"func %s",
fnName,
)
}

var b bytes.Buffer
// fnSpec.Type -> func(params) (rets)
format.Node(&b, fset, fnSpec.Type)

// use everything after func
fmt.Fprint(writer, string(b.Bytes()[4:]))

b.Reset()
Expand All @@ -186,9 +199,18 @@ func WriteFuncWithLocalOverrides(

func generateLocalOverrides(writer io.Writer, overrides map[string]string) {
if len(overrides) > 0 {
for k, v := range overrides {
keys := make([]string, len(overrides))
i := 0
for k := range overrides {
keys[i] = k
i++
}

sort.Strings(keys)

for _, k := range keys {
fmt.Fprintf(writer, `
%s := %s`, k, v)
%s := %s`, k, overrides[k])
}

fmt.Fprintf(writer, "\n")
Expand Down
44 changes: 44 additions & 0 deletions test/clonefn/fn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package clonefn

import (
"encoding/json"
"fmt"
)

func functionThatUsesGlobalFunction(
format string,
args ...interface{},
) string {
//
// skip fansy logic...
//

// call out to a global function in fmt package
return fmt.Sprintf(format, args...)
}

func functionThatUsesMultileGlobalFunctions(
format string,
args ...interface{},
) string {
//
// skip fansy logic...
//

// call out to a global function in fmt package and filepath package
b, _ := json.Marshal(format)
return string(b) + fmt.Sprintf(format, args...)
}

func functionThatUsesMultileGlobalFunctions2(
format string,
args ...interface{},
) string {
//
// skip fansy logic...
//

// call out to a global function in fmt package and filepath package
b, _ := json.Marshal(format)
return string(b) + fmt.Sprintf(format, args...)
}
31 changes: 31 additions & 0 deletions test/clonefn/fn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package clonefn

import (
"testing"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

var jsonMock *mockJson = &mockJson{}
var fmtMock *mockFmt = &mockFmt{}

func TestClonedFuncs(t *testing.T) {
assert := require.New(t)

// setup function mocks
jsonMock.On("Marshal", mock.Anything).Return(([]byte)("mocked Marshal"), nil)
fmtMock.On("Sprintf", mock.Anything, mock.Anything).Return("mocked Sprintf")

// inside functionThatUsesMultileGlobalFunctions: fmt.Sprintf is mocked
assert.True(functionThatUsesGlobalFunction_clone("format", "value") == "mocked Sprintf")

// inside functionThatUsesMultileGlobalFunctions: both json.Marshal()
// and fmt.Sprintf are mocked
assert.True(functionThatUsesMultileGlobalFunctions_clone("format", "value") == "mocked Marshalmocked Sprintf")

// inside functionThatUsesMultileGlobalFunctions2: json.Marshal() is not mocked,
// fmt.Sprintf is mocked
assert.True(functionThatUsesMultileGlobalFunctions2_clone("format", "value") == "\"format\"mocked Sprintf")

}
Loading

0 comments on commit fafb07f

Please sign in to comment.