diff --git a/fake.go b/fake.go index 77c27e8..4e87713 100644 --- a/fake.go +++ b/fake.go @@ -23,20 +23,27 @@ import ( "regexp" "sort" "strings" + "sync" ) // Fake is a fake implementation of Interface type Fake struct { nftContext + // mutex is used to protect Table and LastTransaction. + // When Table and LastTransaction are accessed directly, the caller must acquire Fake.RLock + // and release when finished. + sync.RWMutex nextHandle int // Table contains the Interface's table. This will be `nil` until you `tx.Add()` // the table. + // Make sure to acquire Fake.RLock before accessing Table in a concurrent environment. Table *FakeTable // LastTransaction is the last transaction passed to Run(). It will remain set until the // next time Run() is called. (It is not affected by Check().) + // Make sure to acquire Fake.RLock before accessing LastTransaction in a concurrent environment. LastTransaction *Transaction } @@ -94,6 +101,8 @@ var _ Interface = &Fake{} // List is part of Interface. func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) { + fake.RLock() + defer fake.RUnlock() if fake.Table == nil { return nil, notFoundError("no such table %q", fake.table) } @@ -123,6 +132,8 @@ func (fake *Fake) List(_ context.Context, objectType string) ([]string, error) { // ListRules is part of Interface func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) { + fake.RLock() + defer fake.RUnlock() if fake.Table == nil { return nil, notFoundError("no such table %q", fake.table) } @@ -145,6 +156,8 @@ func (fake *Fake) ListRules(_ context.Context, chain string) ([]*Rule, error) { // ListElements is part of Interface func (fake *Fake) ListElements(_ context.Context, objectType, name string) ([]*Element, error) { + fake.RLock() + defer fake.RUnlock() if fake.Table == nil { return nil, notFoundError("no such %s %q", objectType, name) } @@ -169,6 +182,8 @@ func (fake *Fake) NewTransaction() *Transaction { // Run is part of Interface func (fake *Fake) Run(_ context.Context, tx *Transaction) error { + fake.Lock() + defer fake.Unlock() fake.LastTransaction = tx updatedTable, err := fake.run(tx) if err == nil { @@ -179,10 +194,13 @@ func (fake *Fake) Run(_ context.Context, tx *Transaction) error { // Check is part of Interface func (fake *Fake) Check(_ context.Context, tx *Transaction) error { + fake.RLock() + defer fake.RUnlock() _, err := fake.run(tx) return err } +// must be called with fake.lock held func (fake *Fake) run(tx *Transaction) (*FakeTable, error) { if tx.err != nil { return nil, tx.err @@ -480,6 +498,8 @@ func checkElementRefs(element *Element, table *FakeTable) error { // Dump dumps the current contents of fake, in a way that looks like an nft transaction. func (fake *Fake) Dump() string { + fake.RLock() + defer fake.RUnlock() if fake.Table == nil { return "" }