diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index fcfd43f90026b..609de17622fd6 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -245,7 +245,7 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s } // verifyVMIdentity verifies that the provided access token came from the -// correct Azure VM. Returns the Aure join attributes +// correct Azure VM. Returns the Azure join attributes func verifyVMIdentity( ctx context.Context, cfg *azureRegisterConfig, @@ -291,6 +291,9 @@ func verifyVMIdentity( // from the VM resource. vmSubscription, vmResourceGroup, err := claimsToIdentifiers(tokenClaims) if err == nil { + if subscriptionID != vmSubscription { + return nil, trace.AccessDenied("subscription ID mismatch between attested data and access token") + } return azureJoinToAttrs(vmSubscription, vmResourceGroup), nil } logger.WarnContext(ctx, "Failed to parse VM identifiers from claims. Retrying with Azure VM API.", diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index c7cc7c5b18954..2eeaff3b0e8c0 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -754,6 +754,28 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { certs: []*x509.Certificate{tlsConfig.Certificate}, assertError: require.NoError, }, + { + name: "subscription mismatch between attestation and token", + requestTokenName: "test-token", + tokenSubscription: "attested-subscription", + tokenVMID: defaultVMID, + tokenManagedIdentityResourceID: vmResourceID("token-subscription", defaultResourceGroup, defaultVMName), + tokenSpec: types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleNode}, + Azure: &types.ProvisionTokenSpecV2Azure{ + Allow: []*types.ProvisionTokenSpecV2Azure_Rule{ + { + Subscription: "token-subscription", + ResourceGroups: []string{defaultResourceGroup}, + }, + }, + }, + JoinMethod: types.JoinMethodAzure, + }, + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{tlsConfig.Certificate}, + assertError: isAccessDenied, + }, } for _, tc := range tests {