diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b6b608d..90acbb5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -17,7 +17,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v2 with: - go-version: 1.21 + go-version: 1.23 - name: Check License Headers run: make license-header-check - name: Run Tests diff --git a/go.mod b/go.mod index 08738ab..586c622 100644 --- a/go.mod +++ b/go.mod @@ -220,5 +220,4 @@ require ( go-simpler.org/musttag v0.12.2 // indirect go-simpler.org/sloglint v0.7.2 // indirect go.uber.org/automaxprocs v1.5.3 // indirect - go.uber.org/goleak v1.3.0 // indirect ) diff --git a/go.sum b/go.sum index 30fcc35..b77690e 100644 --- a/go.sum +++ b/go.sum @@ -147,8 +147,8 @@ github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4 github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/firefart/nonamedreturns v1.0.5 h1:tM+Me2ZaXs8tfdDw3X6DOX++wMCOqzYUho6tUTYIdRA= github.com/firefart/nonamedreturns v1.0.5/go.mod h1:gHJjDqhGM4WyPt639SOZs+G89Ko7QKH5R5BhnO6xJhw= -github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= -github.com/frankban/quicktest v1.14.3/go.mod h1:mgiwOwqx65TmIk1wJ6Q7wvnVMocbUorkibMOrVTHZps= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= @@ -424,8 +424,6 @@ github.com/nishanths/exhaustive v0.12.0 h1:vIY9sALmw6T/yxiASewa4TQcFsVYZQQRUQJhK github.com/nishanths/exhaustive v0.12.0/go.mod h1:mEZ95wPIZW+x8kC4TgC+9YCUgiST7ecevsVDTgc2obs= github.com/nishanths/predeclared v0.2.2 h1:V2EPdZPliZymNAn79T8RkNApBjMmVKh5XRpLm/w98Vk= github.com/nishanths/predeclared v0.2.2/go.mod h1:RROzoN6TnGQupbC+lqggsOlcgysk3LMK/HI84Mp280c= -github.com/nitrictech/nitric/core v0.0.0-20240913000004-5d21c28b00ba h1:ZIPl9waqhbqw3xB2+zpUI2T1kEHyMkOnZZMt6tdrEUM= -github.com/nitrictech/nitric/core v0.0.0-20240913000004-5d21c28b00ba/go.mod h1:4LQH9hea9rX+0A+8G47NRk5nZuXCDqiwci1aZsHAkU8= github.com/nitrictech/nitric/core v0.0.0-20241003062412-76ea6275fb0b h1:ImQFk66gRM3v9A6qmPImOiV3HJMDAX93X5rplMKn6ok= github.com/nitrictech/nitric/core v0.0.0-20241003062412-76ea6275fb0b/go.mod h1:9bQnYPqLzq8CcPk5MHT3phg19CWJhDlFOfdIv27lwwM= github.com/nitrictech/protoutils v0.0.0-20220321044654-02667a814cdf h1:8MB8W8ylM8sCM2COGfiO39/tB6BTdiawLszaUGCNL5w= @@ -481,8 +479,7 @@ github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1: github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/client_model v0.6.0 h1:k1v3CzpSRUTrKMppY35TLwPvxHqBu0bYgxZzqGIgaos= github.com/prometheus/client_model v0.6.0/go.mod h1:NTQHnmxFpouOD0DpvP4XujX3CdOAGQPoaGhyTchlyt8= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= @@ -643,8 +640,7 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= -go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= -go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= diff --git a/licenses.go b/licenses.go index 7656872..c29298c 100644 --- a/licenses.go +++ b/licenses.go @@ -20,6 +20,7 @@ package main import ( _ "github.com/nitrictech/go-sdk/nitric" _ "github.com/nitrictech/go-sdk/nitric/apis" + _ "github.com/nitrictech/go-sdk/nitric/batch" _ "github.com/nitrictech/go-sdk/nitric/errors" _ "github.com/nitrictech/go-sdk/nitric/keyvalue" _ "github.com/nitrictech/go-sdk/nitric/queues" diff --git a/nitric/apis/workers.go b/nitric/apis/worker.go similarity index 69% rename from nitric/apis/workers.go rename to nitric/apis/worker.go index 0305b98..42719b6 100644 --- a/nitric/apis/workers.go +++ b/nitric/apis/worker.go @@ -17,7 +17,6 @@ package apis import ( "context" errorsstd "errors" - "io" grpcx "github.com/nitrictech/go-sdk/internal/grpc" "github.com/nitrictech/go-sdk/internal/handlers" @@ -40,7 +39,7 @@ type apiWorkerOpts struct { var _ workers.StreamWorker = (*apiWorker)(nil) -// Start implements Worker. +// Start runs the API worker, creating a stream to the Nitric server func (a *apiWorker) Start(ctx context.Context) error { initReq := &v1.ClientMessage{ Content: &v1.ClientMessage_RegistrationRequest{ @@ -48,46 +47,40 @@ func (a *apiWorker) Start(ctx context.Context) error { }, } - stream, err := a.client.Serve(ctx) - if err != nil { - return err - } - - err = stream.Send(initReq) - if err != nil { - return err + createStream := func(ctx context.Context) (workers.Stream[v1.ClientMessage, v1.RegistrationResponse, *v1.ServerMessage], error) { + return a.client.Serve(ctx) } - for { - var ctx *Ctx - - resp, err := stream.Recv() - - if errorsstd.Is(err, io.EOF) { - err = stream.CloseSend() - if err != nil { - return err - } + handlerSrvMsg := func(msg *v1.ServerMessage) (*v1.ClientMessage, error) { + if msg.GetRegistrationResponse() != nil { + // No need to respond to the registration response + return nil, nil + } - return nil - } else if err == nil && resp.GetRegistrationResponse() != nil { - // There is no need to respond to the registration response - } else if err == nil && resp.GetHttpRequest() != nil { - ctx = NewCtx(resp) + if msg.GetHttpRequest() != nil { + handlerCtx := NewCtx(msg) - err = a.Handler(ctx) + err := a.Handler(handlerCtx) if err != nil { - ctx.WithError(err) + handlerCtx.WithError(err) } - err = stream.Send(ctx.ToClientMessage()) - if err != nil { - return err - } - } else { - return err + return handlerCtx.ToClientMessage(), nil } + + return nil, errors.NewWithCause( + codes.Internal, + "ApiWorker: Unhandled server message", + errorsstd.New("unhandled server message"), + ) } + + return workers.HandleStream( + ctx, + createStream, + initReq, + handlerSrvMsg, + ) } func newApiWorker(opts *apiWorkerOpts) *apiWorker { diff --git a/nitric/batch/batch_workers.go b/nitric/batch/worker.go similarity index 70% rename from nitric/batch/batch_workers.go rename to nitric/batch/worker.go index 7b48710..74300eb 100644 --- a/nitric/batch/batch_workers.go +++ b/nitric/batch/worker.go @@ -16,7 +16,6 @@ package batch import ( "context" - "io" "google.golang.org/grpc" @@ -25,6 +24,7 @@ import ( "github.com/nitrictech/go-sdk/constants" "github.com/nitrictech/go-sdk/nitric/errors" "github.com/nitrictech/go-sdk/nitric/errors/codes" + "github.com/nitrictech/go-sdk/nitric/workers" v1 "github.com/nitrictech/nitric/core/pkg/proto/batch/v1" ) @@ -38,7 +38,7 @@ type jobWorkerOpts struct { Handler Handler } -// Start implements Worker. +// Start runs the Job worker, creating a stream to the Nitric server func (s *jobWorker) Start(ctx context.Context) error { initReq := &v1.ClientMessage{ Content: &v1.ClientMessage_RegistrationRequest{ @@ -46,45 +46,30 @@ func (s *jobWorker) Start(ctx context.Context) error { }, } - // Create the request stream and send the initial request - stream, err := s.client.HandleJob(ctx) - if err != nil { - return err - } - - err = stream.Send(initReq) - if err != nil { - return err + createStream := func(ctx context.Context) (workers.Stream[v1.ClientMessage, v1.RegistrationResponse, *v1.ServerMessage], error) { + return s.client.HandleJob(ctx) } - for { - var ctx *Ctx - resp, err := stream.Recv() + handleSrvMsg := func(msg *v1.ServerMessage) (*v1.ClientMessage, error) { + if msg.GetJobRequest() != nil { + handlerCtx := NewCtx(msg) - if errorsstd.Is(err, io.EOF) { - err = stream.CloseSend() + err := s.handler(handlerCtx) if err != nil { - return err + handlerCtx.WithError(err) } - return nil - } else if err == nil && resp.GetRegistrationResponse() != nil { - // Do nothing - } else if err == nil && resp.GetJobRequest() != nil { - ctx = NewCtx(resp) - err = s.handler(ctx) - if err != nil { - ctx.WithError(err) - } - - err = stream.Send(ctx.ToClientMessage()) - if err != nil { - return err - } - } else { - return err + return handlerCtx.ToClientMessage(), nil } + + return nil, errors.NewWithCause( + codes.Internal, + "JobWorker: Unhandled server message", + errorsstd.New("unhandled server message"), + ) } + + return workers.HandleStream(ctx, createStream, initReq, handleSrvMsg) } func newJobWorker(opts *jobWorkerOpts) *jobWorker { diff --git a/nitric/nitric.go b/nitric/nitric.go index d141ac4..2297adb 100644 --- a/nitric/nitric.go +++ b/nitric/nitric.go @@ -15,6 +15,12 @@ package nitric import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "github.com/nitrictech/go-sdk/nitric/apis" "github.com/nitrictech/go-sdk/nitric/batch" "github.com/nitrictech/go-sdk/nitric/keyvalue" @@ -42,7 +48,19 @@ var ( ) func Run() { - err := workers.GetDefaultManager().Run() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) + + go func() { + <-sigChan + fmt.Printf("Received signal, shutting down...\n") + cancel() + }() + + err := workers.GetDefaultManager().Run(ctx) if err != nil { panic(err) } diff --git a/nitric/schedules/schedule_workers.go b/nitric/schedules/worker.go similarity index 71% rename from nitric/schedules/schedule_workers.go rename to nitric/schedules/worker.go index d94fa68..6ea9bce 100644 --- a/nitric/schedules/schedule_workers.go +++ b/nitric/schedules/worker.go @@ -17,12 +17,12 @@ package schedules import ( "context" errorsstd "errors" - "io" grpcx "github.com/nitrictech/go-sdk/internal/grpc" "github.com/nitrictech/go-sdk/internal/handlers" "github.com/nitrictech/go-sdk/nitric/errors" "github.com/nitrictech/go-sdk/nitric/errors/codes" + "github.com/nitrictech/go-sdk/nitric/workers" v1 "github.com/nitrictech/nitric/core/pkg/proto/schedules/v1" ) @@ -36,7 +36,7 @@ type scheduleWorkerOpts struct { Handler handlers.Handler[Ctx] } -// Start implements Worker. +// Start runs the Schedule worker, creating a stream to the Nitric server func (i *scheduleWorker) Start(ctx context.Context) error { initReq := &v1.ClientMessage{ Content: &v1.ClientMessage_RegistrationRequest{ @@ -44,45 +44,34 @@ func (i *scheduleWorker) Start(ctx context.Context) error { }, } - // Create the request stream and send the initial request - stream, err := i.client.Schedule(ctx) - if err != nil { - return err + createStream := func(ctx context.Context) (workers.Stream[v1.ClientMessage, v1.RegistrationResponse, *v1.ServerMessage], error) { + return i.client.Schedule(ctx) } - err = stream.Send(initReq) - if err != nil { - return err - } - for { - var ctx *Ctx - - resp, err := stream.Recv() - - if errorsstd.Is(err, io.EOF) { - err = stream.CloseSend() - if err != nil { - return err - } + handlerSrvMsg := func(msg *v1.ServerMessage) (*v1.ClientMessage, error) { + if msg.GetIntervalRequest() != nil { + handlerCtx := NewCtx(msg) - return nil - } else if err == nil && resp.GetRegistrationResponse() != nil { - // There is no need to respond to the registration response - } else if err == nil && resp.GetIntervalRequest() != nil { - ctx = NewCtx(resp) - err = i.handler(ctx) + err := i.handler(handlerCtx) if err != nil { - ctx.WithError(err) + handlerCtx.WithError(err) } - err = stream.Send(ctx.ToClientMessage()) - if err != nil { - return err - } - } else { - return err + return handlerCtx.ToClientMessage(), nil } + return nil, errors.NewWithCause( + codes.Internal, + "ScheduleWorker: Unhandled server message", + errorsstd.New("unhandled server message"), + ) } + + return workers.HandleStream( + ctx, + createStream, + initReq, + handlerSrvMsg, + ) } func newScheduleWorker(opts *scheduleWorkerOpts) *scheduleWorker { diff --git a/nitric/storage/bucket_workers.go b/nitric/storage/worker.go similarity index 71% rename from nitric/storage/bucket_workers.go rename to nitric/storage/worker.go index 6e1ba2e..a14b59a 100644 --- a/nitric/storage/bucket_workers.go +++ b/nitric/storage/worker.go @@ -17,12 +17,12 @@ package storage import ( "context" errorsstd "errors" - "io" grpcx "github.com/nitrictech/go-sdk/internal/grpc" "github.com/nitrictech/go-sdk/internal/handlers" "github.com/nitrictech/go-sdk/nitric/errors" "github.com/nitrictech/go-sdk/nitric/errors/codes" + "github.com/nitrictech/go-sdk/nitric/workers" v1 "github.com/nitrictech/nitric/core/pkg/proto/storage/v1" ) @@ -36,7 +36,7 @@ type bucketEventWorkerOpts struct { Handler handlers.Handler[Ctx] } -// Start implements Worker. +// Start runs the BucketEvent worker, creating a stream to the Nitric server func (b *bucketEventWorker) Start(ctx context.Context) error { initReq := &v1.ClientMessage{ Content: &v1.ClientMessage_RegistrationRequest{ @@ -44,45 +44,30 @@ func (b *bucketEventWorker) Start(ctx context.Context) error { }, } - // Create the request stream and send the initial request - stream, err := b.client.Listen(ctx) - if err != nil { - return err - } - - err = stream.Send(initReq) - if err != nil { - return err + createStream := func(ctx context.Context) (workers.Stream[v1.ClientMessage, v1.RegistrationResponse, *v1.ServerMessage], error) { + return b.client.Listen(ctx) } - for { - var ctx *Ctx - resp, err := stream.Recv() + handlerSrvMsg := func(msg *v1.ServerMessage) (*v1.ClientMessage, error) { + if msg.GetBlobEventRequest() != nil { + handlerCtx := NewCtx(msg) - if errorsstd.Is(err, io.EOF) { - err = stream.CloseSend() + err := b.handler(handlerCtx) if err != nil { - return err + handlerCtx.WithError(err) } - return nil - } else if err == nil && resp.GetRegistrationResponse() != nil { - // There is no need to respond to the registration response - } else if err == nil && resp.GetBlobEventRequest() != nil { - ctx = NewCtx(resp) - err = b.handler(ctx) - if err != nil { - ctx.WithError(err) - } - - err = stream.Send(ctx.ToClientMessage()) - if err != nil { - return err - } - } else { - return err + return handlerCtx.ToClientMessage(), nil } + + return nil, errors.NewWithCause( + codes.Internal, + "BucketEventWorker: Unhandled server message", + errorsstd.New("unhandled server message"), + ) } + + return workers.HandleStream(ctx, createStream, initReq, handlerSrvMsg) } func newBucketEventWorker(opts *bucketEventWorkerOpts) *bucketEventWorker { diff --git a/nitric/topics/topic_workers.go b/nitric/topics/worker.go similarity index 72% rename from nitric/topics/topic_workers.go rename to nitric/topics/worker.go index 7d6b7a1..beb54dc 100644 --- a/nitric/topics/topic_workers.go +++ b/nitric/topics/worker.go @@ -16,7 +16,6 @@ package topics import ( "context" - "io" errorsstd "errors" @@ -24,6 +23,7 @@ import ( "github.com/nitrictech/go-sdk/internal/handlers" "github.com/nitrictech/go-sdk/nitric/errors" "github.com/nitrictech/go-sdk/nitric/errors/codes" + "github.com/nitrictech/go-sdk/nitric/workers" v1 "github.com/nitrictech/nitric/core/pkg/proto/topics/v1" ) @@ -45,45 +45,30 @@ func (s *subscriptionWorker) Start(ctx context.Context) error { }, } - // Create the request stream and send the initial request - stream, err := s.client.Subscribe(ctx) - if err != nil { - return err - } - - err = stream.Send(initReq) - if err != nil { - return err + createStream := func(ctx context.Context) (workers.Stream[v1.ClientMessage, v1.RegistrationResponse, *v1.ServerMessage], error) { + return s.client.Subscribe(ctx) } - for { - var ctx *Ctx - resp, err := stream.Recv() + handleSrvMsg := func(msg *v1.ServerMessage) (*v1.ClientMessage, error) { + if msg.GetMessageRequest() != nil { + handlerCtx := NewCtx(msg) - if errorsstd.Is(err, io.EOF) { - err = stream.CloseSend() + err := s.handler(handlerCtx) if err != nil { - return err + handlerCtx.WithError(err) } - return nil - } else if err == nil && resp.GetRegistrationResponse() != nil { - // There is no need to respond to the registration response - } else if err == nil && resp.GetMessageRequest() != nil { - ctx = NewCtx(resp) - err = s.handler(ctx) - if err != nil { - ctx.WithError(err) - } - - err = stream.Send(ctx.ToClientMessage()) - if err != nil { - return err - } - } else { - return err + return handlerCtx.ToClientMessage(), nil } + + return nil, errors.NewWithCause( + codes.Internal, + "SubscriptionWorker: Unhandled server message", + errorsstd.New("unhandled server message"), + ) } + + return workers.HandleStream(ctx, createStream, initReq, handleSrvMsg) } func newSubscriptionWorker(opts *subscriptionWorkerOpts) *subscriptionWorker { diff --git a/nitric/websockets/websocket_workers.go b/nitric/websockets/worker.go similarity index 72% rename from nitric/websockets/websocket_workers.go rename to nitric/websockets/worker.go index e0aafb9..eac8398 100644 --- a/nitric/websockets/websocket_workers.go +++ b/nitric/websockets/worker.go @@ -17,12 +17,12 @@ package websockets import ( "context" errorsstd "errors" - "io" grpcx "github.com/nitrictech/go-sdk/internal/grpc" "github.com/nitrictech/go-sdk/internal/handlers" "github.com/nitrictech/go-sdk/nitric/errors" "github.com/nitrictech/go-sdk/nitric/errors/codes" + "github.com/nitrictech/go-sdk/nitric/workers" v1 "github.com/nitrictech/nitric/core/pkg/proto/websockets/v1" ) @@ -44,45 +44,29 @@ func (w *websocketWorker) Start(ctx context.Context) error { }, } - // Create the request stream and send the initial request - stream, err := w.client.HandleEvents(ctx) - if err != nil { - return err - } - - err = stream.Send(initReq) - if err != nil { - return err + createStream := func(ctx context.Context) (workers.Stream[v1.ClientMessage, v1.RegistrationResponse, *v1.ServerMessage], error) { + return w.client.HandleEvents(ctx) } - for { - var ctx *Ctx - - resp, err := stream.Recv() - - if errorsstd.Is(err, io.EOF) { - err = stream.CloseSend() - if err != nil { - return err - } - return nil - } else if err == nil && resp.GetRegistrationResponse() != nil { - // There is no need to respond to the registration response - } else if err == nil && resp.GetWebsocketEventRequest() != nil { - ctx = NewCtx(resp) - err = w.handler(ctx) - if err != nil { - ctx.WithError(err) - } + handlerSrvMsg := func(msg *v1.ServerMessage) (*v1.ClientMessage, error) { + if msg.GetWebsocketEventRequest() != nil { + handlerCtx := NewCtx(msg) - err = stream.Send(ctx.ToClientMessage()) + err := w.handler(handlerCtx) if err != nil { - return err + handlerCtx.WithError(err) } - } else { - return err + return handlerCtx.ToClientMessage(), nil } + + return nil, errors.NewWithCause( + codes.Internal, + "WebsocketWorker: Unhandled server message", + errorsstd.New("unhandled server message"), + ) } + + return workers.HandleStream(ctx, createStream, initReq, handlerSrvMsg) } func newWebsocketWorker(opts *websocketWorkerOpts) *websocketWorker { diff --git a/nitric/workers/manager.go b/nitric/workers/manager.go index 3a55d33..acfaeee 100644 --- a/nitric/workers/manager.go +++ b/nitric/workers/manager.go @@ -137,7 +137,7 @@ func (m *Manager) RegisterPolicy(res *v1.ResourceIdentifier, actions ...v1.Actio return nil } -func (m *Manager) Run() error { +func (m *Manager) Run(ctx context.Context) error { wg := sync.WaitGroup{} errList := &multierror.ErrorList{} @@ -146,7 +146,7 @@ func (m *Manager) Run() error { go func(s StreamWorker) { defer wg.Done() - if err := s.Start(context.TODO()); err != nil { + if err := s.Start(ctx); err != nil { if isBuildEnvironment() && isEOF(err) { // ignore the EOF error when running code-as-config. return diff --git a/nitric/workers/workers.go b/nitric/workers/workers.go index a8d0f81..d63f2c9 100644 --- a/nitric/workers/workers.go +++ b/nitric/workers/workers.go @@ -14,8 +14,87 @@ package workers -import "context" +import ( + "context" + "errors" + "fmt" + "io" + + "google.golang.org/grpc" +) type StreamWorker interface { Start(context.Context) error } + +type StdServerMsg[RegistrationResponse any] interface { + GetRegistrationResponse() *RegistrationResponse +} + +type Stream[ClientMessage any, RegistrationResponse any, ServerMessage StdServerMsg[RegistrationResponse]] interface { + Send(*ClientMessage) error + Recv() (ServerMessage, error) + grpc.ClientStream +} + +// HandleStream runs a nitric worker, in the standard request/response pattern. +// No changes needed here other than the updated types in the signature. +func HandleStream[ClientMessage any, RegistrationResponse any, ServerMessage StdServerMsg[RegistrationResponse]]( + ctx context.Context, + createStream func(ctx context.Context) (Stream[ClientMessage, RegistrationResponse, ServerMessage], error), + initReq *ClientMessage, + handleServerMsg func(msg ServerMessage) (*ClientMessage, error), +) error { + stream, err := createStream(ctx) + if err != nil { + return err + } + + err = stream.Send(initReq) + if err != nil { + return err + } + + for { + select { + case <-ctx.Done(): + fmt.Printf("Context canceled, closing stream\n") + // If the context is canceled, close the stream and return + err := stream.CloseSend() + if err != nil { + return err + } + return nil + + default: + // Receive the next message + serverMsg, err := stream.Recv() + + if errors.Is(err, io.EOF) { + // Close the stream and exit normally on EOF + err = stream.CloseSend() + if err != nil { + return err + } + return nil + } else if err != nil { + return err + } + + if serverMsg.GetRegistrationResponse() != nil { + // No need to respond to the registration responses (they're just acks) + continue + } + + clientMsg, err := handleServerMsg(serverMsg) + if err != nil { + return err + } + + err = stream.Send(clientMsg) + if err != nil { + return err + } + } + } +}