Skip to content

Commit

Permalink
better return type
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Cosentino committed Jan 19, 2023
1 parent b5baa8c commit 1342787
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 56 deletions.
24 changes: 18 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@
`go-again` **thread safely** wraps a given function and executes it until it returns a nil error or exceeds the maximum number of retries.
The configuration consists of the maximum number of retries, the interval, a jitter to add a randomized backoff, the timeout, and a registry to store errors that you consider temporary, hence worth a retry.
The `Retry` method takes a context, a function, and an optional list of `temporary errors` as arguments. It supports cancellation from the context and a channel invoking the `Cancel()` function.
The returned type is `RetryErrors` which contains the list of errors returned at each attempt and the last error returned by the function.

```golang
// RetryErrors holds the error returned by the retry function along with the trace of each attempt.
type RetryErrors struct {
// Retries hold the trace of each attempt.
Retries map[int]error
// ExitError holds the last error returned by the retry function.
ExitError error
}
```

The registry only allows you to retry a function if it returns a registered error:

```go
Expand All @@ -17,15 +29,15 @@ The registry only allows you to retry a function if it returns a registered erro

defer retrier.Registry.UnRegisterTemporaryError("http.ErrAbortHandler")
var retryCount int
err := retrier.Retry(context.TODO(), func() error {
errs := retrier.Retry(context.TODO(), func() error {
retryCount++
if retryCount < 3 {
return http.ErrAbortHandler
}
return nil
}, "http.ErrAbortHandler")

if err != nil {
if errs.ExitError != nil {
// handle error
}
```
Expand All @@ -37,14 +49,14 @@ Should you retry regardless of the error returned, that's easy. It's enough call
retrier := again.NewRetrier(again.WithTimeout(1*time.Second),
again.WithJitter(500*time.Millisecond),
again.WithMaxRetries(3))
err := retrier.Retry(context.TODO(), func() error {
errs := retrier.Retry(context.TODO(), func() error {
retryCount++
if retryCount < 3 {
return http.ErrAbortHandler
}
return nil
})
if err != nil {
if errs.ExitError != nil {
// handle error
}
```
Expand Down Expand Up @@ -109,11 +121,11 @@ func main() {
})

// Retry a function.
err := retrier.Retry(context.TODO(), func() error {
errs := retrier.Retry(context.TODO(), func() error {
// Do something here.
return fmt.Errorf("temporary error")
}, "temporary error")
if err != nil {
if errs.ExitError != nil {
fmt.Println(err)
}
}
Expand Down
8 changes: 5 additions & 3 deletions examples/chan/chan.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ func main() {
}

// Retry the function.
err := retrier.Retry(ctx, fn)
if err != nil {
fmt.Println(err)
errs := retrier.Retry(ctx, fn)
if errs.ExitError != nil {
fmt.Println(errs)
} else {
fmt.Println("success")
}
}
9 changes: 6 additions & 3 deletions examples/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ func main() {
}

// Retry the function.
err := retrier.Retry(ctx, fn)
if err != nil {
fmt.Println(err)
errs := retrier.Retry(ctx, fn)

if errs.ExitError != nil {
fmt.Println(errs)
} else {
fmt.Println("success")
}
}
84 changes: 49 additions & 35 deletions retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,31 @@ package again
import (
"context"
"errors"
"fmt"
"math/rand"
"strconv"
"sync"
"time"
)

var (
// retryErrorPool is a pool of RetryError objects.
retryErrorPool = sync.Pool{
// retryErrorsPool is a pool of RetryErrors objects.
retryErrorsPool = sync.Pool{
New: func() interface{} {
return &RetryError{}
return &RetryErrors{
Retries: make(map[int]error),
}
},
}
)

// RetryableFunc signature of retryable function
type RetryableFunc func() error

// RetryError is an error returned by the Retry function when the maximum number of retries is reached.
type RetryError struct {
MaxRetries int
Err error
}

// Error returns the error message.
func (e *RetryError) Error() string {
return "maximum number of retries (" + strconv.Itoa(e.MaxRetries) + ") reached: " + e.Err.Error()
}

// Unwrap returns the underlying error.
func (e *RetryError) Unwrap() error {
return e.Err
// RetryErrors holds the error returned by the retry function along with the trace of each attempt.
type RetryErrors struct {
// Retries holds the trace of each attempt.
Retries map[int]error
// ExitError holds the last error returned by the retry function.
ExitError error
}

// Retrier is a type that retries a function until it returns a nil error or the maximum number of retries is reached.
Expand Down Expand Up @@ -94,47 +86,70 @@ func NewRetrier(opts ...Option) *Retrier {
}

// SetRegistry sets the registry for temporary errors.
func (r *Retrier) SetRegistry(reg *registry) {
// Use this function to set a custom registry if:
// - you want to add custom temporary errors.
// - you want to remove the default temporary errors.
// - you want to replace the default temporary errors with your own.
// - you have initialized the Retrier without using the constructor `NewRetrier`.
func (r *Retrier) SetRegistry(reg *registry) error {
// set the registry if not nil.
if reg == nil {
return errors.New("registry cannot be nil")
}
// set the registry if not already set.
r.once.Do(func() {
r.Registry = reg
})
return nil
}

// Retry retries a `retryableFunc` until it returns a nil error or the maximum number of retries is reached.
// - If the maximum number of retries is reached, the function returns a `RetryError` object.
// - If the `retryableFunc` returns a nil error, the function returns nil.
// - If the `retryableFunc` returns a nil error, the function assigns a `RetryErrors.ExitError` before returning.
// - If the `retryableFunc` returns a temporary error, the function retries the function.
// - If the `retryableFunc` returns a non-temporary error, the function returns the error.
// - If the `retryableFunc` returns a non-temporary error, the function assigns the error to `RetryErrors.ExitError` and returns.
// - If the `temporaryErrors` list is empty, the function retries the function until the maximum number of retries is reached.
// - The context is used to cancel the retries, or set a deadline if the `retryableFunc` hangs.
func (r *Retrier) Retry(ctx context.Context, retryableFunc RetryableFunc, temporaryErrors ...string) error {
func (r *Retrier) Retry(ctx context.Context, retryableFunc RetryableFunc, temporaryErrors ...string) (errs *RetryErrors) {
// lock the mutex to synchronize access to the timer.
r.mutex.RLock()
defer r.mutex.RUnlock()

// get a new RetryErrors object from the pool.
errs = retryErrorsPool.Get().(*RetryErrors)
defer retryErrorsPool.Put(errs)

// If the maximum number of retries is 0, call the function once and return the result.
if r.MaxRetries == 0 {
return retryableFunc()
errs.ExitError = retryableFunc()
return
}

// Check for invalid inputs.
if retryableFunc == nil {
return errors.New("failed to invoke the function. It appears to be is nil")
errs.ExitError = errors.New("failed to invoke the function. It appears to be is nil")
return
}

// `rng` is the random number generator used to apply jitter to the retry interval.
rng := rand.New(rand.NewSource(time.Now().UnixNano()))

// Retry the function until it returns a nil error or the maximum number of retries is reached.
for i := 0; i < r.MaxRetries; i++ {
for attempt := 0; attempt < r.MaxRetries; attempt++ {
// Increment the attempt.

// Call the function to retry.
r.err = retryableFunc()

// If the function returns a nil error, return nil.
if r.err == nil {
return nil
errs.ExitError = nil
return
}

// Set the error returned by the function.
errs.Retries[attempt] = r.err

// Check if the error returned by the function is temporary when the list of temporary errors is not empty.
if len(temporaryErrors) > 0 && !r.IsTemporaryError(r.err, temporaryErrors...) {
break
Expand All @@ -143,7 +158,8 @@ func (r *Retrier) Retry(ctx context.Context, retryableFunc RetryableFunc, tempor
// Check if the context is cancelled.
select {
case <-ctx.Done():
return ctx.Err()
errs.ExitError = ctx.Err()
return
default:
}

Expand All @@ -156,20 +172,18 @@ func (r *Retrier) Retry(ctx context.Context, retryableFunc RetryableFunc, tempor
select {
case <-r.cancel:
r.timer.put(timer)
return fmt.Errorf("retries cancelled")
errs.ExitError = errors.New("retries cancelled")
return
case <-r.stop:
r.timer.put(timer)
return r.err
errs.ExitError = errors.New("retries stopped")
return
case <-timer.C:
r.timer.put(timer)
}
}

// Return an error indicating that the maximum number of retries was reached.
retryErr := retryErrorPool.Get().(*RetryError)
retryErr.MaxRetries = r.MaxRetries
retryErr.Err = r.err
return retryErr
return
}

// Cancel cancels the retries notifying the `Retry` function to return.
Expand Down
18 changes: 9 additions & 9 deletions retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ func TestRetry(t *testing.T) {

defer retrier.Registry.UnRegisterTemporaryError("http.ErrAbortHandler")

err := retrier.Retry(context.TODO(), func() error {
errs := retrier.Retry(context.TODO(), func() error {
retryCount++
if retryCount < 3 {
return http.ErrAbortHandler
}
return nil
}, "http.ErrAbortHandler")

if err != nil {
t.Errorf("retry returned an unexpected error: %v", err)
if errs.ExitError != nil {
t.Errorf("retry returned an unexpected error: %v", errs.ExitError)
}
if retryCount != 3 {
t.Errorf("retry did not retry the function the expected number of times. Got: %d, Expecting: %d", retryCount, 3)
Expand All @@ -38,16 +38,16 @@ func TestWithoutRegistry(t *testing.T) {
var retryCount int
retrier := NewRetrier()

err := retrier.Retry(context.TODO(), func() error {
errs := retrier.Retry(context.TODO(), func() error {
retryCount++
if retryCount < 3 {
return http.ErrAbortHandler
}
return nil
})

if err != nil {
t.Errorf("retry returned an unexpected error: %v", err)
if errs.ExitError != nil {
t.Errorf("retry returned an unexpected error: %v", errs.ExitError)
}
if retryCount != 3 {
t.Errorf("retry did not retry the function the expected number of times. Got: %d, Expecting: %d", retryCount, 1)
Expand All @@ -62,16 +62,16 @@ func TestRetryWithDefaults(t *testing.T) {

defer retrier.Registry.Clean()

err := retrier.Retry(context.TODO(), func() error {
errs := retrier.Retry(context.TODO(), func() error {
retryCount++
if retryCount < 3 {
return http.ErrHandlerTimeout
}
return nil
}, "http.ErrHandlerTimeout")

if err != nil {
t.Errorf("retry returned an unexpected error: %v", err)
if errs.ExitError != nil {
t.Errorf("retry returned an unexpected error: %v", errs.ExitError)
}
if retryCount != 3 {
t.Errorf("retry did not retry the function the expected number of times. Got: %d, Expecting: %d", retryCount, 3)
Expand Down

0 comments on commit 1342787

Please sign in to comment.