Skip to content

Commit

Permalink
Allow concurrent calls for Fake methods.
Browse files Browse the repository at this point in the history
Signed-off-by: Nadia Pinaeva <n.m.pinaeva@gmail.com>
  • Loading branch information
npinaeva committed Sep 2, 2024
1 parent 19fb4da commit bb730df
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ import (
"regexp"
"sort"
"strings"
"sync"
)

// Fake is a fake implementation of Interface
type Fake struct {
nftContext
// lock is used to protect Table and LastTransaction.
// When Table and LastTransaction are accessed directly, it is caller's responsibility
// to ensure there is no data race.
lock sync.RWMutex

nextHandle int

Expand Down Expand Up @@ -94,6 +99,8 @@ var _ Interface = &Fake{}

// List is part of Interface.
func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) {
fake.lock.RLock()
defer fake.lock.RUnlock()
if fake.Table == nil {
return nil, notFoundError("no such table %q", fake.table)
}
Expand Down Expand Up @@ -123,6 +130,8 @@ func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) {

// ListRules is part of Interface
func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) {
fake.lock.RLock()
defer fake.lock.RUnlock()
if fake.Table == nil {
return nil, notFoundError("no such table %q", fake.table)
}
Expand All @@ -145,6 +154,8 @@ func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) {

// ListElements is part of Interface
func (fake *Fake) ListElements(_ context.Context, objectType, name string) ([]*Element, error) {
fake.lock.RLock()
defer fake.lock.RUnlock()
if fake.Table == nil {
return nil, notFoundError("no such %s %q", objectType, name)
}
Expand All @@ -169,6 +180,8 @@ func (fake *Fake) NewTransaction() *Transaction {

// Run is part of Interface
func (fake *Fake) Run(_ context.Context, tx *Transaction) error {
fake.lock.Lock()
defer fake.lock.Unlock()
fake.LastTransaction = tx
updatedTable, err := fake.run(tx)
if err == nil {
Expand All @@ -179,10 +192,13 @@ func (fake *Fake) Run(_ context.Context, tx *Transaction) error {

// Check is part of Interface
func (fake *Fake) Check(_ context.Context, tx *Transaction) error {
fake.lock.RLock()
defer fake.lock.RUnlock()
_, err := fake.run(tx)
return err
}

// must be called with fake.lock held
func (fake *Fake) run(tx *Transaction) (*FakeTable, error) {
if tx.err != nil {
return nil, tx.err
Expand Down Expand Up @@ -480,6 +496,8 @@ func checkElementRefs(element *Element, table *FakeTable) error {

// Dump dumps the current contents of fake, in a way that looks like an nft transaction.
func (fake *Fake) Dump() string {
fake.lock.RLock()
defer fake.lock.RUnlock()
if fake.Table == nil {
return ""
}
Expand Down

0 comments on commit bb730df

Please sign in to comment.