diff --git a/fake.go b/fake.go index 77c27e8..e9a3787 100644 --- a/fake.go +++ b/fake.go @@ -23,11 +23,16 @@ import ( "regexp" "sort" "strings" + "sync" ) // Fake is a fake implementation of Interface type Fake struct { nftContext + // lock is used to protect Table and LastTransaction. + // When Table and LastTransaction are accessed directly, it is caller's responsibility + // to ensure there is no data race. + lock sync.RWMutex nextHandle int @@ -94,6 +99,8 @@ var _ Interface = &Fake{} // List is part of Interface. func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) { + fake.lock.RLock() + defer fake.lock.RUnlock() if fake.Table == nil { return nil, notFoundError("no such table %q", fake.table) } @@ -123,6 +130,8 @@ func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) { // ListRules is part of Interface func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) { + fake.lock.RLock() + defer fake.lock.RUnlock() if fake.Table == nil { return nil, notFoundError("no such table %q", fake.table) } @@ -145,6 +154,8 @@ func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) { // ListElements is part of Interface func (fake *Fake) ListElements(_ context.Context, objectType, name string) ([]*Element, error) { + fake.lock.RLock() + defer fake.lock.RUnlock() if fake.Table == nil { return nil, notFoundError("no such %s %q", objectType, name) } @@ -169,6 +180,8 @@ func (fake *Fake) NewTransaction() *Transaction { // Run is part of Interface func (fake *Fake) Run(_ context.Context, tx *Transaction) error { + fake.lock.Lock() + defer fake.lock.Unlock() fake.LastTransaction = tx updatedTable, err := fake.run(tx) if err == nil { @@ -179,10 +192,13 @@ func (fake *Fake) Run(_ context.Context, tx *Transaction) error { // Check is part of Interface func (fake *Fake) Check(_ context.Context, tx *Transaction) error { + fake.lock.RLock() + defer fake.lock.RUnlock() _, err := fake.run(tx) return err } +// must be called with fake.lock held func (fake *Fake) run(tx *Transaction) (*FakeTable, error) { if tx.err != nil { return nil, tx.err @@ -480,6 +496,8 @@ func checkElementRefs(element *Element, table *FakeTable) error { // Dump dumps the current contents of fake, in a way that looks like an nft transaction. func (fake *Fake) Dump() string { + fake.lock.RLock() + defer fake.lock.RUnlock() if fake.Table == nil { return "" }