diff --git a/cluster/job.go b/cluster/job.go
index d751abc..0c82671 100644
--- a/cluster/job.go
+++ b/cluster/job.go
@@ -58,10 +58,15 @@ func Schedule(pluginAPI JobPluginAPI, key string, config JobConfig, callback fun
 
 	key = cronPrefix + key
 
+	mutex, err := NewMutex(pluginAPI, key)
+	if err != nil {
+		return nil, errors.Wrap(err, "failed to create job mutex")
+	}
+
 	job := &Job{
 		pluginAPI: pluginAPI,
 		key:       key,
-		mutex:     NewMutex(pluginAPI, key),
+		mutex:     mutex,
 		config:    config,
 		callback:  callback,
 		stop:      make(chan bool),
diff --git a/cluster/job_example_test.go b/cluster/job_example_test.go
index 13593b5..585a792 100644
--- a/cluster/job_example_test.go
+++ b/cluster/job_example_test.go
@@ -1,10 +1,14 @@
 package cluster
 
-import "time"
+import (
+	"time"
+
+	"github.com/mattermost/mattermost-server/v5/plugin"
+)
 
 func ExampleSchedule() {
 	// Use p.API from your plugin instead.
-	pluginAPI := NewMockMutexPluginAPI(nil)
+	pluginAPI := plugin.API(nil)
 
 	callback := func() {
 		// periodic work to do
diff --git a/cluster/job_test.go b/cluster/job_test.go
index 5aa955a..e0988b3 100644
--- a/cluster/job_test.go
+++ b/cluster/job_test.go
@@ -6,38 +6,39 @@ import (
 	"testing"
 	"time"
 
+	"github.com/mattermost/mattermost-server/v5/model"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
-type MockJobPluginAPI struct {
-	*MockMutexPluginAPI
-}
+func TestSchedule(t *testing.T) {
+	t.Parallel()
 
-func NewMockJobPluginAPI(t *testing.T) *MockJobPluginAPI {
-	return &MockJobPluginAPI{
-		MockMutexPluginAPI: NewMockMutexPluginAPI(t),
+	makeKey := func() string {
+		return model.NewId()
 	}
-}
 
-func TestSchedule(t *testing.T) {
 	t.Run("invalid interval", func(t *testing.T) {
-		mockPluginAPI := NewMockJobPluginAPI(t)
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
 
-		job, err := Schedule(mockPluginAPI, "key", JobConfig{}, func() {})
+		job, err := Schedule(mockPluginAPI, makeKey(), JobConfig{}, func() {})
 		require.Error(t, err, "must specify non-zero job config interval")
 		require.Nil(t, job)
 	})
 
 	t.Run("single-threaded", func(t *testing.T) {
-		mockPluginAPI := NewMockJobPluginAPI(t)
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
 
 		count := new(int32)
 		callback := func() {
 			atomic.AddInt32(count, 1)
 		}
 
-		job, err := Schedule(mockPluginAPI, "key", JobConfig{Interval: 100 * time.Millisecond}, callback)
+		job, err := Schedule(mockPluginAPI, makeKey(), JobConfig{Interval: 100 * time.Millisecond}, callback)
 		require.NoError(t, err)
 		require.NotNil(t, job)
 
@@ -56,7 +57,9 @@ func TestSchedule(t *testing.T) {
 	})
 
 	t.Run("multi-threaded, single job", func(t *testing.T) {
-		mockPluginAPI := NewMockJobPluginAPI(t)
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
 
 		count := new(int32)
 		callback := func() {
@@ -65,8 +68,10 @@ func TestSchedule(t *testing.T) {
 
 		var jobs []*Job
 
+		key := makeKey()
+
 		for i := 0; i < 3; i++ {
-			job, err := Schedule(mockPluginAPI, "key", JobConfig{Interval: 100 * time.Millisecond}, callback)
+			job, err := Schedule(mockPluginAPI, key, JobConfig{Interval: 100 * time.Millisecond}, callback)
 			require.NoError(t, err)
 			require.NotNil(t, job)
 
@@ -97,7 +102,9 @@ func TestSchedule(t *testing.T) {
 	})
 
 	t.Run("multi-threaded, multiple jobs", func(t *testing.T) {
-		mockPluginAPI := NewMockJobPluginAPI(t)
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
 
 		countA := new(int32)
 		callbackA := func() {
@@ -109,15 +116,18 @@ func TestSchedule(t *testing.T) {
 			atomic.AddInt32(countB, 1)
 		}
 
+		keyA := makeKey()
+		keyB := makeKey()
+
 		var jobs []*Job
 
 		for i := 0; i < 3; i++ {
 			var key string
 			var callback func()
 			if i <= 1 {
-				key = "keyA"
+				key = keyA
 				callback = callbackA
 			} else {
-				key = "keyB"
+				key = keyB
 				callback = callbackB
 			}
diff --git a/cluster/mock_plugin_api_test.go b/cluster/mock_plugin_api_test.go
new file mode 100644
index 0000000..c3907c9
--- /dev/null
+++ b/cluster/mock_plugin_api_test.go
@@ -0,0 +1,87 @@
+package cluster
+
+import (
+	"bytes"
+	"sync"
+	"testing"
+
+	"github.com/mattermost/mattermost-server/v5/model"
+)
+
+type mockPluginAPI struct {
+	t *testing.T
+
+	lock      sync.Mutex
+	keyValues map[string][]byte
+	failing   bool
+}
+
+func newMockPluginAPI(t *testing.T) *mockPluginAPI {
+	return &mockPluginAPI{
+		t:         t,
+		keyValues: make(map[string][]byte),
+	}
+}
+
+func (pluginAPI *mockPluginAPI) setFailing(failing bool) {
+	pluginAPI.lock.Lock()
+	defer pluginAPI.lock.Unlock()
+
+	pluginAPI.failing = failing
+}
+
+func (pluginAPI *mockPluginAPI) clear() {
+	pluginAPI.lock.Lock()
+	defer pluginAPI.lock.Unlock()
+
+	for k := range pluginAPI.keyValues {
+		delete(pluginAPI.keyValues, k)
+	}
+}
+
+func (pluginAPI *mockPluginAPI) KVGet(key string) ([]byte, *model.AppError) {
+	pluginAPI.lock.Lock()
+	defer pluginAPI.lock.Unlock()
+
+	if pluginAPI.failing {
+		return nil, &model.AppError{Message: "fake error"}
+	}
+
+	return pluginAPI.keyValues[key], nil
+}
+
+func (pluginAPI *mockPluginAPI) KVSetWithOptions(key string, value []byte, options model.PluginKVSetOptions) (bool, *model.AppError) {
+	pluginAPI.lock.Lock()
+	defer pluginAPI.lock.Unlock()
+
+	if pluginAPI.failing {
+		return false, &model.AppError{Message: "fake error"}
+	}
+
+	if options.Atomic {
+		if actualValue := pluginAPI.keyValues[key]; !bytes.Equal(actualValue, options.OldValue) {
+			return false, nil
+		}
+	}
+
+	if value == nil {
+		delete(pluginAPI.keyValues, key)
+	} else {
+		pluginAPI.keyValues[key] = value
+	}
+
+	return true, nil
+}
+
+func (pluginAPI *mockPluginAPI) LogError(msg string, keyValuePairs ...interface{}) {
+	if pluginAPI.t == nil {
+		return
+	}
+
+	pluginAPI.t.Helper()
+
+	params := []interface{}{msg}
+	params = append(params, keyValuePairs...)
+
+	pluginAPI.t.Log(params...)
+}
diff --git a/cluster/mutex.go b/cluster/mutex.go
index 206720e..858cdcb 100644
--- a/cluster/mutex.go
+++ b/cluster/mutex.go
@@ -1,6 +1,7 @@
 package cluster
 
 import (
+	"context"
 	"sync"
 	"time"
 
@@ -20,18 +21,6 @@ const (
 
 	// refreshInterval is the interval on which the mutex will be refreshed when locked
 	refreshInterval = ttl / 2
-
-	// minWaitInterval is the minimum amount of time to wait between locking attempts
-	minWaitInterval = 1 * time.Second
-
-	// maxWaitInterval is the maximum amount of time to wait between locking attempts
-	maxWaitInterval = 5 * time.Minute
-
-	// pollWaitInterval is the usual time to wait between unsuccessful locking attempts
-	pollWaitInterval = 1 * time.Second
-
-	// jitterWaitInterval is the amount of jitter to add when waiting to avoid thundering herds
-	jitterWaitInterval = minWaitInterval / 2
 )
 
 // MutexPluginAPI is the plugin API interface required to manage mutexes.
@@ -60,14 +49,28 @@ type Mutex struct {
 	refreshDone chan bool
 }
 
-// NewMutex creates a mutex with the given name.
-func NewMutex(pluginAPI MutexPluginAPI, key string) *Mutex {
-	key = mutexPrefix + key
+// NewMutex creates a mutex with the given key name.
+//
+// Returns an error if key is empty.
+func NewMutex(pluginAPI MutexPluginAPI, key string) (*Mutex, error) {
+	key, err := makeLockKey(key)
+	if err != nil {
+		return nil, err
+	}
 
 	return &Mutex{
 		pluginAPI: pluginAPI,
 		key:       key,
+	}, nil
+}
+
+// makeLockKey returns the prefixed key used to namespace mutex keys.
+func makeLockKey(key string) (string, error) {
+	if len(key) == 0 {
+		return "", errors.New("must specify valid mutex key")
 	}
+
+	return mutexPrefix + key, nil
 }
 
 // lock makes a single attempt to atomically lock the mutex, returning true only if successful.
@@ -85,8 +88,6 @@ func (m *Mutex) tryLock() (bool, error) {
 }
 
 // refreshLock rewrites the lock key value with a new expiry, returning true only if successful.
-//
-// Only call this while holding the lock.
 func (m *Mutex) refreshLock() error {
 	ok, err := m.pluginAPI.KVSetWithOptions(m.key, []byte{1}, model.PluginKVSetOptions{
 		Atomic:   true,
@@ -105,10 +106,23 @@
 // Lock locks m. If the mutex is already locked by any plugin instance, including the current one,
 // the calling goroutine blocks until the mutex can be locked.
 func (m *Mutex) Lock() {
+	m.LockWithContext(context.Background())
+}
+
+// LockWithContext locks m unless the context is cancelled. If the mutex is already locked by any plugin
+// instance, including the current one, the calling goroutine blocks until the mutex can be locked,
+// or the context is cancelled.
+//
+// The mutex is locked only if a nil error is returned.
+func (m *Mutex) LockWithContext(ctx context.Context) error {
 	var waitInterval time.Duration
 
 	for {
-		time.Sleep(waitInterval)
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-time.After(waitInterval):
+		}
 
 		locked, err := m.tryLock()
 		if err != nil {
@@ -144,7 +158,7 @@
 		m.refreshDone = done
 		m.lock.Unlock()
 
-		return
+		return nil
 	}
 }
 
diff --git a/cluster/mutex_example_test.go b/cluster/mutex_example_test.go
index 5f2d039..73bcbf9 100644
--- a/cluster/mutex_example_test.go
+++ b/cluster/mutex_example_test.go
@@ -9,7 +9,10 @@ func ExampleMutex() {
 	// Use p.API from your plugin instead.
 	pluginAPI := plugin.API(nil)
 
-	m := cluster.NewMutex(pluginAPI, "key")
+	m, err := cluster.NewMutex(pluginAPI, "key")
+	if err != nil {
+		panic(err)
+	}
 	m.Lock()
 	// critical section
 	m.Unlock()
diff --git a/cluster/mutex_test.go b/cluster/mutex_test.go
index 7a5d827..65e6a98 100644
--- a/cluster/mutex_test.go
+++ b/cluster/mutex_test.go
@@ -1,8 +1,7 @@
 package cluster
 
 import (
-	"bytes"
-	"sync"
+	"context"
 	"testing"
 	"time"
 
@@ -11,82 +10,43 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
-type MockMutexPluginAPI struct {
-	t *testing.T
-
-	lock      sync.Mutex
-	keyValues map[string][]byte
-	failing   bool
-}
-
-func NewMockMutexPluginAPI(t *testing.T) *MockMutexPluginAPI {
-	return &MockMutexPluginAPI{
-		t:         t,
-		keyValues: make(map[string][]byte),
+func mustMakeLockKey(key string) string {
+	key, err := makeLockKey(key)
+	if err != nil {
+		panic(err)
 	}
-}
-
-func (pluginAPI *MockMutexPluginAPI) setFailing(failing bool) {
-	pluginAPI.lock.Lock()
-	defer pluginAPI.lock.Unlock()
 
-	pluginAPI.failing = failing
+	return key
 }
 
-func (pluginAPI *MockMutexPluginAPI) clear() {
-	pluginAPI.lock.Lock()
-	defer pluginAPI.lock.Unlock()
-
-	for k := range pluginAPI.keyValues {
-		delete(pluginAPI.keyValues, k)
-	}
-}
-
-func (pluginAPI *MockMutexPluginAPI) KVGet(key string) ([]byte, *model.AppError) {
-	pluginAPI.lock.Lock()
-	defer pluginAPI.lock.Unlock()
-
-	if pluginAPI.failing {
-		return nil, &model.AppError{Message: "fake error"}
+func mustNewMutex(pluginAPI MutexPluginAPI, key string) *Mutex {
+	m, err := NewMutex(pluginAPI, key)
+	if err != nil {
+		panic(err)
 	}
 
-	return pluginAPI.keyValues[key], nil
+	return m
 }
 
-func (pluginAPI *MockMutexPluginAPI) KVSetWithOptions(key string, value []byte, options model.PluginKVSetOptions) (bool, *model.AppError) {
-	pluginAPI.lock.Lock()
-	defer pluginAPI.lock.Unlock()
-
-	if pluginAPI.failing {
-		return false, &model.AppError{Message: "fake error"}
-	}
+func TestMakeLockKey(t *testing.T) {
+	t.Run("fails when empty", func(t *testing.T) {
+		key, err := makeLockKey("")
+		assert.Error(t, err)
+		assert.Empty(t, key)
+	})
 
-	if options.Atomic {
-		if actualValue := pluginAPI.keyValues[key]; !bytes.Equal(actualValue, options.OldValue) {
-			return false, nil
+	t.Run("not-empty", func(t *testing.T) {
+		testCases := map[string]string{
+			"key":   mutexPrefix + "key",
+			"other": mutexPrefix + "other",
 		}
-	}
-
-	if value == nil {
-		delete(pluginAPI.keyValues, key)
-	} else {
-		pluginAPI.keyValues[key] = value
-	}
-
-	return true, nil
-}
-
-func (pluginAPI *MockMutexPluginAPI) LogError(msg string, keyValuePairs ...interface{}) {
-	if pluginAPI.t == nil {
-		return
-	}
-	pluginAPI.t.Helper()
-
-	params := []interface{}{msg}
-	params = append(params, keyValuePairs...)
-
-	pluginAPI.t.Log(params...)
+
+		for key, expected := range testCases {
+			actual, err := makeLockKey(key)
+			require.NoError(t, err)
+			assert.Equal(t, expected, actual)
+		}
+	})
 }
 
 func lock(t *testing.T, m *Mutex) {
@@ -94,6 +54,8 @@
 
 	done := make(chan bool)
 	go func() {
+		t.Helper()
+
 		defer close(done)
 		m.Lock()
 	}()
@@ -110,6 +72,8 @@
 
 	done := make(chan bool)
 	go func() {
+		t.Helper()
+
 		defer close(done)
 		if panics {
 			assert.Panics(t, m.Unlock)
@@ -126,10 +90,18 @@
 }
 
 func TestMutex(t *testing.T) {
+	t.Parallel()
+
+	makeKey := func() string {
+		return model.NewId()
+	}
+
 	t.Run("successful lock/unlock cycle", func(t *testing.T) {
-		mockPluginAPI := NewMockMutexPluginAPI(t)
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
 
-		m := NewMutex(mockPluginAPI, "key")
+		m := mustNewMutex(mockPluginAPI, makeKey())
 		lock(t, m)
 		unlock(t, m, false)
 		lock(t, m)
@@ -137,16 +109,20 @@
 	})
 
 	t.Run("unlock when not locked", func(t *testing.T) {
-		mockPluginAPI := NewMockMutexPluginAPI(t)
+		t.Parallel()
 
-		m := NewMutex(mockPluginAPI, "key")
+		mockPluginAPI := newMockPluginAPI(t)
+
+		m := mustNewMutex(mockPluginAPI, makeKey())
 		unlock(t, m, true)
 	})
 
 	t.Run("blocking lock", func(t *testing.T) {
-		mockPluginAPI := NewMockMutexPluginAPI(t)
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
 
-		m := NewMutex(mockPluginAPI, "key")
+		m := mustNewMutex(mockPluginAPI, makeKey())
 		lock(t, m)
 
 		done := make(chan bool)
@@ -171,9 +147,11 @@
 	})
 
 	t.Run("failed lock", func(t *testing.T) {
-		mockPluginAPI := NewMockMutexPluginAPI(t)
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
 
-		m := NewMutex(mockPluginAPI, "key")
+		m := mustNewMutex(mockPluginAPI, makeKey())
 
 		mockPluginAPI.setFailing(true)
 
@@ -199,9 +177,12 @@
 	})
 
 	t.Run("failed unlock", func(t *testing.T) {
-		mockPluginAPI := NewMockMutexPluginAPI(t)
+		t.Parallel()
 
-		m := NewMutex(mockPluginAPI, "key")
+		mockPluginAPI := newMockPluginAPI(t)
+
+		key := makeKey()
+		m := mustNewMutex(mockPluginAPI, key)
 		lock(t, m)
 
 		mockPluginAPI.setFailing(true)
@@ -216,15 +197,17 @@
 	})
 
 	t.Run("discrete keys", func(t *testing.T) {
-		mockPluginAPI := NewMockMutexPluginAPI(t)
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
 
-		m1 := NewMutex(mockPluginAPI, "key1")
+		m1 := mustNewMutex(mockPluginAPI, makeKey())
 		lock(t, m1)
 
-		m2 := NewMutex(mockPluginAPI, "key2")
+		m2 := mustNewMutex(mockPluginAPI, makeKey())
 		lock(t, m2)
 
-		m3 := NewMutex(mockPluginAPI, "key3")
+		m3 := mustNewMutex(mockPluginAPI, makeKey())
 		lock(t, m3)
 
 		unlock(t, m1, false)
@@ -235,4 +218,69 @@
 		unlock(t, m2, false)
 		unlock(t, m1, false)
 	})
+
+	t.Run("with uncancelled context", func(t *testing.T) {
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
+
+		key := makeKey()
+		m := mustNewMutex(mockPluginAPI, key)
+
+		m.Lock()
+
+		ctx := context.Background()
+		done := make(chan bool)
+		go func() {
+			defer close(done)
+			err := m.LockWithContext(ctx)
+			require.Nil(t, err)
+		}()
+
+		select {
+		case <-time.After(ttl + pollWaitInterval*2):
+		case <-done:
+			require.Fail(t, "goroutine should not have locked")
+		}
+
+		m.Unlock()
+
+		select {
+		case <-time.After(pollWaitInterval * 2):
+			require.Fail(t, "goroutine should have locked after unlock")
+		case <-done:
+		}
+	})
+
+	t.Run("with cancelled context", func(t *testing.T) {
+		t.Parallel()
+
+		mockPluginAPI := newMockPluginAPI(t)
+
+		m := mustNewMutex(mockPluginAPI, makeKey())
+
+		m.Lock()
+
+		ctx, cancel := context.WithCancel(context.Background())
+		done := make(chan bool)
+		go func() {
+			defer close(done)
+			err := m.LockWithContext(ctx)
+			require.NotNil(t, err)
+		}()
+
+		select {
+		case <-time.After(ttl + pollWaitInterval*2):
+		case <-done:
+			require.Fail(t, "goroutine should not have locked")
+		}
+
+		cancel()
+
+		select {
+		case <-time.After(pollWaitInterval * 2):
+			require.Fail(t, "goroutine should have aborted after cancellation")
+		case <-done:
+		}
+	})
 }
diff --git a/cluster/wait.go b/cluster/wait.go
index bee8df9..bf62b4a 100644
--- a/cluster/wait.go
+++ b/cluster/wait.go
@@ -5,6 +5,20 @@ import (
 	"time"
 )
 
+const (
+	// minWaitInterval is the minimum amount of time to wait between locking attempts
+	minWaitInterval = 1 * time.Second
+
+	// maxWaitInterval is the maximum amount of time to wait between locking attempts
+	maxWaitInterval = 5 * time.Minute
+
+	// pollWaitInterval is the usual time to wait between unsuccessful locking attempts
+	pollWaitInterval = 1 * time.Second
+
+	// jitterWaitInterval is the amount of jitter to add when waiting to avoid thundering herds
+	jitterWaitInterval = minWaitInterval / 2
+)
+
 // nextWaitInterval determines how long to wait until the next lock retry.
 func nextWaitInterval(lastWaitInterval time.Duration, err error) time.Duration {
 	nextWaitInterval := lastWaitInterval
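
Reviewer note: below is a minimal sketch of how a plugin might consume the new API end to end, combining the now-fallible NewMutex with LockWithContext to bound how long a caller waits for the lock. The helper name withClusterLock, its parameters, and the cluster import path are illustrative assumptions, not part of this change.

	package example

	import (
		"context"
		"time"

		// Assumed import path for this repository's cluster package.
		"github.com/mattermost/mattermost-plugin-api/cluster"
	)

	// withClusterLock runs fn while holding a cluster-wide mutex, giving up
	// if the lock cannot be acquired within the given timeout.
	func withClusterLock(pluginAPI cluster.MutexPluginAPI, key string, timeout time.Duration, fn func()) error {
		// NewMutex now reports an error up front (e.g. for an empty key).
		m, err := cluster.NewMutex(pluginAPI, key)
		if err != nil {
			return err
		}

		ctx, cancel := context.WithTimeout(context.Background(), timeout)
		defer cancel()

		// LockWithContext blocks until the lock is acquired or ctx is done;
		// the mutex is held only when a nil error is returned.
		if err := m.LockWithContext(ctx); err != nil {
			return err
		}
		defer m.Unlock()

		fn()
		return nil
	}

In a plugin this would be invoked with p.API, which satisfies cluster.MutexPluginAPI, e.g. withClusterLock(p.API, "migration", 30*time.Second, doMigration).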