Skip to content

Commit

Permalink
Make read notify use sync.Cond instead of chan
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels committed Jul 22, 2024
1 parent d55a60c commit 4d3a2f6
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 31 deletions.
90 changes: 66 additions & 24 deletions packetio/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ type Buffer struct {
data []byte
head, tail int

notify chan struct{}
waiting bool
closed bool
closed bool

count int
limitCount, limitSize int

readDeadline *deadline.Deadline
readDeadline *deadline.Deadline
nextDeadline chan struct{}
readNotifier *sync.Cond
readChannelWatcherRunning sync.WaitGroup
}

const (
Expand All @@ -56,9 +57,35 @@ const (

// NewBuffer creates a new Buffer.
func NewBuffer() *Buffer {
return &Buffer{
notify: make(chan struct{}, 1),
buffer := &Buffer{
readDeadline: deadline.New(),
nextDeadline: make(chan struct{}, 1),
}
buffer.readNotifier = sync.NewCond(&buffer.mutex)
buffer.readChannelWatcherRunning.Add(1)
go buffer.readDeadlineWatcher()
return buffer
}

func (b *Buffer) readDeadlineWatcher() {
defer b.readChannelWatcherRunning.Done()
for {
select {
case <-b.readDeadline.Done():
b.mutex.Lock()
b.readNotifier.Broadcast()
b.mutex.Unlock()
case _, ok := <-b.nextDeadline:
if ok {
continue
}
return
}

_, ok := <-b.nextDeadline
if !ok {
return
}
}
}

Expand Down Expand Up @@ -173,15 +200,7 @@ func (b *Buffer) Write(packet []byte) (int, error) {
}
b.count++

waiting := b.waiting
b.waiting = false

if waiting {
select {
case b.notify <- struct{}{}:
default:
}
}
b.readNotifier.Signal()
b.mutex.Unlock()

return len(packet), nil
Expand All @@ -199,9 +218,8 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
default:
}

b.mutex.Lock()
for {
b.mutex.Lock()

if b.head != b.tail {
// decode the packet size
n1 := b.data[b.head]
Expand Down Expand Up @@ -244,7 +262,6 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
}

b.count--
b.waiting = false
b.mutex.Unlock()

if copied < count {
Expand All @@ -258,32 +275,47 @@ func (b *Buffer) Read(packet []byte) (n int, err error) { //nolint:gocognit
return 0, io.EOF
}

b.waiting = true
b.mutex.Unlock()

b.readNotifier.Wait()
select {
case <-b.readDeadline.Done():
b.mutex.Unlock()
return 0, &netError{ErrTimeout, true, true}
case <-b.notify:
default:
}
}
}

// Close the buffer, unblocking any pending reads.
// Data in the buffer can still be read, Read will return io.EOF only when empty.
func (b *Buffer) Close() (err error) {
return b.close(false)
}

// GracefulClose closes the buffer, unblocking any pending reads.
// Data in the buffer can still be read, Read will return io.EOF only when empty.
// It returns when any goroutines Buffer started have completed. This should not be called
// in any callbacks that may own a buffer unless a goroutine is spawned in that callback
// to call GracefulClose.
func (b *Buffer) GracefulClose() (err error) {
return b.close(true)
}

func (b *Buffer) close(graceful bool) error {
b.mutex.Lock()

if b.closed {
b.mutex.Unlock()
return nil
}

b.waiting = false
b.closed = true
close(b.notify)
close(b.nextDeadline)
b.readNotifier.Broadcast()
b.mutex.Unlock()

if graceful {
b.readChannelWatcherRunning.Wait()
}
return nil
}

Expand Down Expand Up @@ -338,6 +370,16 @@ func (b *Buffer) SetLimitSize(limit int) {
// SetReadDeadline sets the deadline for the Read operation.
// Setting to zero means no deadline.
func (b *Buffer) SetReadDeadline(t time.Time) error {
b.mutex.Lock()
defer b.mutex.Unlock()

b.readDeadline.Set(t)
select {
case b.nextDeadline <- struct{}{}:
default:
// if there is no receiver, then we know that readDeadlineWatcher
// is about to receive the buffered value in the channel. otherwise
// we communicated the next deadline to the receiver directly.
}
return nil
}
75 changes: 68 additions & 7 deletions packetio/buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"time"

"github.com/pion/transport/v3/test"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -430,10 +431,10 @@ func TestBufferAlloc(t *testing.T) {
}
}

t.Run("100 writes", test(w, 100, 11))
t.Run("200 writes", test(w, 200, 14))
t.Run("400 writes", test(w, 400, 17))
t.Run("1000 writes", test(w, 1000, 21))
t.Run("100 writes", test(w, 100, 14))
t.Run("200 writes", test(w, 200, 20))
t.Run("400 writes", test(w, 400, 20))
t.Run("1000 writes", test(w, 1000, 26))

wr := func(count int) func() {
return func() {
Expand All @@ -451,9 +452,9 @@ func TestBufferAlloc(t *testing.T) {
}
}

t.Run("100 writes and reads", test(wr, 100, 5))
t.Run("1000 writes and reads", test(wr, 1000, 5))
t.Run("10000 writes and reads", test(wr, 10000, 5))
t.Run("100 writes and reads", test(wr, 100, 18))
t.Run("1000 writes and reads", test(wr, 1000, 18))
t.Run("10000 writes and reads", test(wr, 10000, 18))
}

func benchmarkBufferWR(b *testing.B, size int64, write bool, grow int) { // nolint:unparam
Expand Down Expand Up @@ -584,6 +585,8 @@ func BenchmarkBuffer1400(b *testing.B) {
}

func TestBufferConcurrentRead(t *testing.T) {
defer test.TimeOut(time.Second * 5).Stop()

assert := assert.New(t)

buffer := NewBuffer()
Expand Down Expand Up @@ -626,3 +629,61 @@ func TestBufferConcurrentRead(t *testing.T) {
err = <-errCh
assert.Equal(io.EOF, err)
}

func TestBufferConcurrentReadWrite(t *testing.T) {
defer test.TimeOut(time.Second * 5).Stop()

assert := assert.New(t)

buffer := NewBuffer()
packet := make([]byte, 4)

errCh := make(chan error, 4)
readIntoErr := func() {
_, readErr := buffer.Read(packet)
errCh <- readErr
}
writeIntoErr := func() {
_, writeErr := buffer.Write([]byte{2, 3, 4})
errCh <- writeErr
}
go readIntoErr()
go readIntoErr()
go writeIntoErr()
go writeIntoErr()

// Close
err := buffer.Close()
assert.NoError(err)

// we just care that the reads and writes happen
for i := 0; i < 4; i++ {
<-errCh
}
}

func TestBufferReadDeadlineInSyncCond(t *testing.T) {
defer test.TimeOut(time.Second * 10).Stop()

assert := assert.New(t)

buffer := NewBuffer()

assert.NoError(buffer.SetReadDeadline(time.Now().Add(5 * time.Second))) // Set deadline to avoid test deadlock

// Start up a goroutine to start a blocking read.
readErr := make(chan error)
go func() {
packet := make([]byte, 4)
_, err := buffer.Read(packet)
readErr <- err
}()

err := <-readErr
var e net.Error
if !errors.As(err, &e) || !e.Timeout() {
t.Errorf("Unexpected error: %v", err)
}

assert.NoError(buffer.GracefulClose())
}

0 comments on commit 4d3a2f6

Please sign in to comment.