Skip to content

Commit

Permalink
Cancelable reads (#120)
Browse files Browse the repository at this point in the history
This commit implements cancelable reads, which allows Bubble Tea programs to run in succession in a single application. It also makes sure all goroutines terminate before `Program.Start()` returns.

Closes #24.
  • Loading branch information
erikgeiser authored Sep 28, 2021
1 parent 7396e37 commit e402e8b
Show file tree
Hide file tree
Showing 11 changed files with 906 additions and 26 deletions.
72 changes: 72 additions & 0 deletions cancelreader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package tea

import (
"fmt"
"io"
"sync"
)

var errCanceled = fmt.Errorf("read cancelled")

// cancelReader is a io.Reader whose Read() calls can be cancelled without data
// being consumed. The cancelReader has to be closed.
type cancelReader interface {
io.ReadCloser

// Cancel cancels ongoing and future reads an returns true if it succeeded.
Cancel() bool
}

// fallbackCancelReader implements cancelReader but does not actually support
// cancelation during an ongoing Read() call. Thus, Cancel() always returns
// false. However, after calling Cancel(), new Read() calls immediately return
// errCanceled and don't consume any data anymore.
type fallbackCancelReader struct {
r io.Reader
cancelled bool
}

// newFallbackCancelReader is a fallback for newCancelReader that cannot
// actually cancel an ongoing read but will immediately return on future reads
// if it has been cancelled.
func newFallbackCancelReader(reader io.Reader) (cancelReader, error) {
return &fallbackCancelReader{r: reader}, nil
}

func (r *fallbackCancelReader) Read(data []byte) (int, error) {
if r.cancelled {
return 0, errCanceled
}

return r.r.Read(data)
}

func (r *fallbackCancelReader) Cancel() bool {
r.cancelled = true

return false
}

func (r *fallbackCancelReader) Close() error {
return nil
}

// cancelMixin represents a goroutine-safe cancelation status.
type cancelMixin struct {
unsafeCancelled bool
lock sync.Mutex
}

func (c *cancelMixin) isCancelled() bool {
c.lock.Lock()
defer c.lock.Unlock()

return c.unsafeCancelled
}

func (c *cancelMixin) setCancelled() {
c.lock.Lock()
defer c.lock.Unlock()

c.unsafeCancelled = true
}
144 changes: 144 additions & 0 deletions cancelreader_bsd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// +build darwin freebsd netbsd openbsd

// nolint:revive
package tea

import (
"errors"
"fmt"
"io"
"os"
"strings"

"golang.org/x/sys/unix"
)

// newkqueueCancelReader returns a reader and a cancel function. If the input reader
// is an *os.File, the cancel function can be used to interrupt a blocking call
// read call. In this case, the cancel function returns true if the call was
// cancelled successfully. If the input reader is not a *os.File, the cancel
// function does nothing and always returns false. The BSD and macOS
// implementation is based on the kqueue mechanism.
func newCancelReader(reader io.Reader) (cancelReader, error) {
file, ok := reader.(*os.File)
if !ok {
return newFallbackCancelReader(reader)
}

// kqueue returns instantly when polling /dev/tty so fallback to select
if file.Name() == "/dev/tty" {
return newSelectCancelReader(reader)
}

kQueue, err := unix.Kqueue()
if err != nil {
return nil, fmt.Errorf("create kqueue: %w", err)
}

r := &kqueueCancelReader{
file: file,
kQueue: kQueue,
}

r.cancelSignalReader, r.cancelSignalWriter, err = os.Pipe()
if err != nil {
return nil, err
}

unix.SetKevent(&r.kQueueEvents[0], int(file.Fd()), unix.EVFILT_READ, unix.EV_ADD)
unix.SetKevent(&r.kQueueEvents[1], int(r.cancelSignalReader.Fd()), unix.EVFILT_READ, unix.EV_ADD)

return r, nil
}

type kqueueCancelReader struct {
file *os.File
cancelSignalReader *os.File
cancelSignalWriter *os.File
cancelMixin
kQueue int
kQueueEvents [2]unix.Kevent_t
}

func (r *kqueueCancelReader) Read(data []byte) (int, error) {
if r.isCancelled() {
return 0, errCanceled
}

err := r.wait()
if err != nil {
if errors.Is(err, errCanceled) {
// remove signal from pipe
var b [1]byte
_, errRead := r.cancelSignalReader.Read(b[:])
if errRead != nil {
return 0, fmt.Errorf("reading cancel signal: %w", errRead)
}
}

return 0, err
}

return r.file.Read(data)
}

func (r *kqueueCancelReader) Cancel() bool {
r.setCancelled()

// send cancel signal
_, err := r.cancelSignalWriter.Write([]byte{'c'})
return err == nil
}

func (r *kqueueCancelReader) Close() error {
var errMsgs []string

// close kqueue
err := unix.Close(r.kQueue)
if err != nil {
errMsgs = append(errMsgs, fmt.Sprintf("closing kqueue: %v", err))
}

// close pipe
err = r.cancelSignalWriter.Close()
if err != nil {
errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal writer: %v", err))
}

err = r.cancelSignalReader.Close()
if err != nil {
errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal reader: %v", err))
}

if len(errMsgs) > 0 {
return fmt.Errorf(strings.Join(errMsgs, ", "))
}

return nil
}

func (r *kqueueCancelReader) wait() error {
events := make([]unix.Kevent_t, 1)

for {
_, err := unix.Kevent(r.kQueue, r.kQueueEvents[:], events, nil)
if errors.Is(err, unix.EINTR) {
continue // try again if the syscall was interrupted
}

if err != nil {
return fmt.Errorf("kevent: %w", err)
}

break
}

switch events[0].Ident {
case uint64(r.file.Fd()):
return nil
case uint64(r.cancelSignalReader.Fd()):
return errCanceled
}

return fmt.Errorf("unknown error")
}
14 changes: 14 additions & 0 deletions cancelreader_default.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//go:build !darwin && !windows && !linux && !solaris && !freebsd && !netbsd && !openbsd
// +build !darwin,!windows,!linux,!solaris,!freebsd,!netbsd,!openbsd

package tea

import (
"io"
)

// newCancelReader returns a fallbackCancelReader that satisfies the
// cancelReader but does not actually support cancelation.
func newCancelReader(reader io.Reader) (cancelReader, error) {
return newFallbackCancelReader(reader)
}
152 changes: 152 additions & 0 deletions cancelreader_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//go:build linux
// +build linux

// nolint:revive
package tea

import (
"errors"
"fmt"
"io"
"os"
"strings"

"golang.org/x/sys/unix"
)

// newCancelReader returns a reader and a cancel function. If the input reader
// is an *os.File, the cancel function can be used to interrupt a blocking call
// read call. In this case, the cancel function returns true if the call was
// cancelled successfully. If the input reader is not a *os.File, the cancel
// function does nothing and always returns false. The linux implementation is
// based on the epoll mechanism.
func newCancelReader(reader io.Reader) (cancelReader, error) {
file, ok := reader.(*os.File)
if !ok {
return newFallbackCancelReader(reader)
}

epoll, err := unix.EpollCreate1(0)
if err != nil {
return nil, fmt.Errorf("create epoll: %w", err)
}

r := &epollCancelReader{
file: file,
epoll: epoll,
}

r.cancelSignalReader, r.cancelSignalWriter, err = os.Pipe()
if err != nil {
return nil, err
}

err = unix.EpollCtl(epoll, unix.EPOLL_CTL_ADD, int(file.Fd()), &unix.EpollEvent{
Events: unix.EPOLLIN,
Fd: int32(file.Fd()),
})
if err != nil {
return nil, fmt.Errorf("add reader to epoll interrest list")
}

err = unix.EpollCtl(epoll, unix.EPOLL_CTL_ADD, int(r.cancelSignalReader.Fd()), &unix.EpollEvent{
Events: unix.EPOLLIN,
Fd: int32(r.cancelSignalReader.Fd()),
})
if err != nil {
return nil, fmt.Errorf("add reader to epoll interrest list")
}

return r, nil
}

type epollCancelReader struct {
file *os.File
cancelSignalReader *os.File
cancelSignalWriter *os.File
cancelMixin
epoll int
}

func (r *epollCancelReader) Read(data []byte) (int, error) {
if r.isCancelled() {
return 0, errCanceled
}

err := r.wait()
if err != nil {
if errors.Is(err, errCanceled) {
// remove signal from pipe
var b [1]byte
_, readErr := r.cancelSignalReader.Read(b[:])
if readErr != nil {
return 0, fmt.Errorf("reading cancel signal: %w", readErr)
}
}

return 0, err
}

return r.file.Read(data)
}

func (r *epollCancelReader) Cancel() bool {
r.setCancelled()

// send cancel signal
_, err := r.cancelSignalWriter.Write([]byte{'c'})
return err == nil
}

func (r *epollCancelReader) Close() error {
var errMsgs []string

// close kqueue
err := unix.Close(r.epoll)
if err != nil {
errMsgs = append(errMsgs, fmt.Sprintf("closing epoll: %v", err))
}

// close pipe
err = r.cancelSignalWriter.Close()
if err != nil {
errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal writer: %v", err))
}

err = r.cancelSignalReader.Close()
if err != nil {
errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal reader: %v", err))
}

if len(errMsgs) > 0 {
return fmt.Errorf(strings.Join(errMsgs, ", "))
}

return nil
}

func (r *epollCancelReader) wait() error {
events := make([]unix.EpollEvent, 1)

for {
_, err := unix.EpollWait(r.epoll, events, -1)
if errors.Is(err, unix.EINTR) {
continue // try again if the syscall was interrupted
}

if err != nil {
return fmt.Errorf("kevent: %w", err)
}

break
}

switch events[0].Fd {
case int32(r.file.Fd()):
return nil
case int32(r.cancelSignalReader.Fd()):
return errCanceled
}

return fmt.Errorf("unknown error")
}
Loading

0 comments on commit e402e8b

Please sign in to comment.