diff --git a/internal/internal_nexus_task_handler.go b/internal/internal_nexus_task_handler.go index 6161d37a7..1627acf8d 100644 --- a/internal/internal_nexus_task_handler.go +++ b/internal/internal_nexus_task_handler.go @@ -71,6 +71,7 @@ type nexusTaskHandler struct { failureConverter converter.FailureConverter logger log.Logger metricsHandler metrics.Handler + registry *registry } func newNexusTaskHandler( @@ -83,6 +84,7 @@ func newNexusTaskHandler( failureConverter converter.FailureConverter, logger log.Logger, metricsHandler metrics.Handler, + registry *registry, ) *nexusTaskHandler { return &nexusTaskHandler{ nexusHandler: nexusHandler, @@ -94,6 +96,7 @@ func newNexusTaskHandler( taskQueueName: taskQueueName, client: client, metricsHandler: metricsHandler, + registry: registry, } } @@ -393,6 +396,7 @@ func (h *nexusTaskHandler) newNexusOperationContext(response *workflowservice.Po TaskQueue: h.taskQueueName, MetricsHandler: metricsHandler, Log: logger, + registry: h.registry, }, nil } diff --git a/internal/internal_nexus_worker.go b/internal/internal_nexus_worker.go index 26b38b1d8..2021cfd28 100644 --- a/internal/internal_nexus_worker.go +++ b/internal/internal_nexus_worker.go @@ -32,6 +32,7 @@ type nexusWorkerOptions struct { client Client workflowService workflowservice.WorkflowServiceClient handler nexus.Handler + registry *registry } type nexusWorker struct { @@ -57,6 +58,7 @@ func newNexusWorker(opts nexusWorkerOptions) (*nexusWorker, error) { opts.executionParameters.FailureConverter, opts.executionParameters.Logger, opts.executionParameters.MetricsHandler, + opts.registry, ), opts.workflowService, params, diff --git a/internal/internal_worker.go b/internal/internal_worker.go index 8355158c5..68cc9d771 100644 --- a/internal/internal_worker.go +++ b/internal/internal_worker.go @@ -1128,6 +1128,7 @@ func (aw *AggregatedWorker) start() error { client: aw.client, workflowService: aw.client.workflowService, handler: handler, + registry: aw.registry, }) if err != nil { return fmt.Errorf("failed to create a nexus worker: %w", err) diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index d446f9a09..134c7b26a 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -2408,6 +2408,7 @@ func (env *testWorkflowEnvironmentImpl) newTestNexusTaskHandler( env.failureConverter, env.logger, env.metricsHandler, + env.registry, ) } diff --git a/internal/nexus_operations.go b/internal/nexus_operations.go index 0d1136078..a55403c05 100644 --- a/internal/nexus_operations.go +++ b/internal/nexus_operations.go @@ -42,11 +42,16 @@ import ( // NexusOperationContext is an internal only struct that holds fields used by the temporalnexus functions. type NexusOperationContext struct { - Client Client - Namespace string - TaskQueue string - MetricsHandler metrics.Handler - Log log.Logger + Client Client + Namespace string + TaskQueue string + MetricsHandler metrics.Handler + Log log.Logger + registry *registry +} + +func (nc *NexusOperationContext) ResolveWorkflowName(wf any) (string, error) { + return getWorkflowFunctionName(nc.registry, wf) } type nexusOperationContextKeyType struct{} diff --git a/temporalnexus/operation.go b/temporalnexus/operation.go index 6cc078400..58b4b8b4e 100644 --- a/temporalnexus/operation.go +++ b/temporalnexus/operation.go @@ -324,6 +324,12 @@ func ExecuteUntypedWorkflow[R any]( if !ok { return nil, nexus.HandlerErrorf(nexus.HandlerErrorTypeInternal, "internal error") } + + workflowType, err := nctx.ResolveWorkflowName(workflow) + if err != nil { + panic(err) + } + if startWorkflowOptions.TaskQueue == "" { startWorkflowOptions.TaskQueue = nctx.TaskQueue } @@ -373,7 +379,7 @@ func ExecuteUntypedWorkflow[R any]( } internal.SetLinksOnStartWorkflowOptions(&startWorkflowOptions, links) - run, err := nctx.Client.ExecuteWorkflow(ctx, startWorkflowOptions, workflow, args...) + run, err := nctx.Client.ExecuteWorkflow(ctx, startWorkflowOptions, workflowType, args...) if err != nil { return nil, err } diff --git a/test/nexus_test.go b/test/nexus_test.go index 8a4f0b1f6..12dcb0f1e 100644 --- a/test/nexus_test.go +++ b/test/nexus_test.go @@ -626,7 +626,7 @@ func TestAsyncOperationFromWorkflow(t *testing.T) { service := nexus.NewService("test") require.NoError(t, service.Register(op)) w.RegisterNexusService(service) - w.RegisterWorkflow(handlerWorkflow) + w.RegisterWorkflowWithOptions(handlerWorkflow, workflow.RegisterOptions{Name: "foo"}) w.RegisterWorkflow(callerWorkflow) require.NoError(t, w.Start()) t.Cleanup(w.Stop) @@ -693,6 +693,9 @@ func TestAsyncOperationFromWorkflow(t *testing.T) { } } require.NotNil(t, targetEvent) + // Verify that calling by name works. + require.Equal(t, "foo", targetEvent.GetWorkflowExecutionStartedEventAttributes().WorkflowType.Name) + // Verify that links are properly attached. require.Len(t, targetEvent.GetLinks(), 1) require.True(t, proto.Equal( &common.Link_WorkflowEvent{