Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce --use-latest-expiring-certificate Option #98

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 92 additions & 30 deletions aws_signing_helper/cert_store_signer_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"io"
"log"
"sort"
"unsafe"
)

Expand All @@ -43,7 +44,7 @@ const (

// Gets the matching identity and certificate for this CertIdentifier
// If there is more than one, only a list of the matching certificates is returned
func GetMatchingCertsAndIdentity(certIdentifier CertIdentifier) (C.SecIdentityRef, C.SecCertificateRef, []CertificateContainer, error) {
func GetMatchingCertsAndIdentity(certIdentifier CertIdentifier) ([]C.SecIdentityRef, []C.SecCertificateRef, []CertificateContainer, error) {
queryMap := map[C.CFTypeRef]C.CFTypeRef{
C.CFTypeRef(C.kSecClass): C.CFTypeRef(C.kSecClassIdentity),
C.CFTypeRef(C.kSecReturnRef): C.CFTypeRef(C.kCFBooleanTrue),
Expand All @@ -52,16 +53,16 @@ func GetMatchingCertsAndIdentity(certIdentifier CertIdentifier) (C.SecIdentityRe

query := mapToCFDictionary(queryMap)
if query == 0 {
return 0, 0, nil, errors.New("error creating CFDictionary")
return nil, nil, nil, errors.New("error creating CFDictionary")
}
defer C.CFRelease(C.CFTypeRef(query))

var absResult C.CFTypeRef
if err := osStatusError(C.SecItemCopyMatching(query, &absResult)); err != nil {
if err == errSecItemNotFound {
return 0, 0, nil, errors.New("unable to find matching identity in cert store")
return nil, nil, nil, errors.New("unable to find matching identity in cert store")
}
return 0, 0, nil, err
return nil, nil, nil, err
}
defer C.CFRelease(C.CFTypeRef(absResult))
aryResult := C.CFArrayRef(absResult)
Expand All @@ -71,13 +72,14 @@ func GetMatchingCertsAndIdentity(certIdentifier CertIdentifier) (C.SecIdentityRe
identRefs := make([]C.CFTypeRef, numIdentRefs)
C.CFArrayGetValues(aryResult, C.CFRange{0, numIdentRefs}, (*unsafe.Pointer)(unsafe.Pointer(&identRefs[0])))
var certContainers []CertificateContainer
var certRef C.SecCertificateRef
var identRef C.SecIdentityRef
var certRefs []C.SecCertificateRef
var outputIdentRefs []C.SecIdentityRef
var isMatch bool
certContainerIndex := 0
for _, curIdentRef := range identRefs {
curCertRef, err := getCertRef(C.SecIdentityRef(curIdentRef))
if err != nil {
return 0, 0, nil, errors.New("unable to get cert ref")
return nil, nil, nil, errors.New("unable to get cert ref")
}
curCert, err := exportCertRef(curCertRef)
if err != nil {
Expand All @@ -90,14 +92,20 @@ func GetMatchingCertsAndIdentity(certIdentifier CertIdentifier) (C.SecIdentityRe
// Find whether there is a matching certificate
isMatch = certMatches(certIdentifier, *curCert)
if isMatch {
certContainers = append(certContainers, CertificateContainer{curCert, ""})
certContainers = append(certContainers, CertificateContainer{certContainerIndex, curCert, ""})
certContainerIndex += 1
// Assign to certRef and identRef at most once in the loop
// Both values are only useful if there is exactly one match in the certificate store
// When creating a signer, there has to be exactly one matching certificate
if certRef == 0 {
certRef = curCertRef
identRef = C.SecIdentityRef(curIdentRef)
}
// if certRef == 0 {
// certRef = curCertRef
// identRef = C.SecIdentityRef(curIdentRef)
// }

certRefs = append(certRefs, curCertRef)
// Note that only the SecIdentityRef needs to be retained since it was neither created nor copied
C.CFRetain(C.CFTypeRef(curIdentRef))
outputIdentRefs = append(outputIdentRefs, C.SecIdentityRef(curIdentRef))
}

nextIteration:
Expand All @@ -109,37 +117,74 @@ func GetMatchingCertsAndIdentity(certIdentifier CertIdentifier) (C.SecIdentityRe

// Only retain the SecIdentityRef if it should be used later on
// Note that only the SecIdentityRef needs to be retained since it was neither created nor copied
if len(certContainers) == 1 {
C.CFRetain(C.CFTypeRef(identRef))
return identRef, certRef, certContainers, nil
} else {
return 0, 0, certContainers, nil
}
// if len(certContainers) == 1 {
// C.CFRetain(C.CFTypeRef(identRef))
// return identRef, certRef, certContainers, nil
// } else {
// return 0, 0, certContainers, nil
// }

// It's the caller's responsibility to release each SecIdentityRef after use.
return outputIdentRefs, certRefs, certContainers, nil
}

// Gets the certificates that match the CertIdentifier
func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, error) {
identRef, certRef, certContainers, err := GetMatchingCertsAndIdentity(certIdentifier)
if len(certContainers) == 1 {
identRefs, certRefs, certContainers, err := GetMatchingCertsAndIdentity(certIdentifier)
for i, identRef := range identRefs {
C.CFRelease(C.CFTypeRef(identRef))
identRefs[i] = 0
}
for i, certRef := range certRefs {
C.CFRelease(C.CFTypeRef(certRef))
certRefs[i] = 0
}
return certContainers, err
}

// Creates a DarwinCertStoreSigner based on the identifying certificate
func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAlgorithm string, err error) {
identRef, certRef, certContainers, err := GetMatchingCertsAndIdentity(certIdentifier)
func GetCertStoreSigner(certIdentifier CertIdentifier, useLatestExpiringCert bool) (signer Signer, signingAlgorithm string, err error) {
var (
selectedCertContainer CertificateContainer
cert *x509.Certificate
identRef C.SecIdentityRef
certRef C.SecCertificateRef
keyRef C.SecKeyRef
)

identRefs, certRefs, certContainers, err := GetMatchingCertsAndIdentity(certIdentifier)
if err != nil {
return nil, "", err
goto fail
}
if len(certContainers) == 0 {
return nil, "", errors.New("no matching identities")
err = errors.New("no matching identities")
goto fail
}
if useLatestExpiringCert {
sort.Sort(CertificateContainerList(certContainers))
// Release the `SecIdentityRef`s and `SecCertificateRef`s that won't be used
for i, certContainer := range certContainers {
if i != len(certContainers)-1 {
C.CFRelease(C.CFTypeRef(identRefs[certContainer.Index]))
C.CFRelease(C.CFTypeRef(certRefs[certContainer.Index]))

identRefs[certContainer.Index] = 0
certRefs[certContainer.Index] = 0
}
}
} else {
if len(certContainers) > 1 {
err = errors.New("multiple matching identities")
goto fail
}
}
if len(certContainers) > 1 {
return nil, "", errors.New("multiple matching identities")
selectedCertContainer = certContainers[len(certContainers)-1]
if Debug {
log.Print(fmt.Sprintf("selected certificate: %s", DefaultCertContainerToString(selectedCertContainer)))
}
cert := certContainers[0].Cert
cert = selectedCertContainer.Cert
certRef = certRefs[selectedCertContainer.Index]
identRef = identRefs[selectedCertContainer.Index]

// Find the signing algorithm
switch cert.PublicKey.(type) {
Expand All @@ -148,15 +193,32 @@ func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAl
case *rsa.PublicKey:
signingAlgorithm = aws4_x509_rsa_sha256
default:
return nil, "", errors.New("unsupported algorithm")
err = errors.New("unsupported algorithm")
goto fail
}

keyRef, err := getKeyRef(identRef)
keyRef, err = getKeyRef(identRef)
if err != nil {
return nil, "", errors.New("unable to get key reference")
err = errors.New("unable to get key reference")
goto fail
}

return &DarwinCertStoreSigner{identRef, keyRef, certRef, cert, nil}, signingAlgorithm, nil

fail:
for i, identRef := range identRefs {
if identRef != 0 {
C.CFRelease(C.CFTypeRef(identRef))
identRefs[i] = 0
}
}
for i, certRef := range certRefs {
if certRef != 0 {
C.CFRelease(C.CFTypeRef(certRef))
certRefs[i] = 0
}
}
return nil, "", err
}

// Gets the certificate associated with this DarwinCertStoreSigner
Expand Down
2 changes: 1 addition & 1 deletion aws_signing_helper/cert_store_signer_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, er
return nil, errors.New("unable to use cert store signer on linux")
}

func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAlgorithm string, err error) {
func GetCertStoreSigner(certIdentifier CertIdentifier, useLatestExpiringCert bool) (signer Signer, signingAlgorithm string, err error) {
return nil, "", errors.New("unable to use cert store signer on linux")
}
92 changes: 66 additions & 26 deletions aws_signing_helper/cert_store_signer_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"golang.org/x/sys/windows"
"io"
"log"
"sort"
"strconv"
"strings"
"unsafe"
Expand Down Expand Up @@ -145,7 +146,7 @@ func (secStatus securityStatus) Error() string {
// Gets the certificates that match the given CertIdentifier within the user's specified system
// certificate store. By default, that is "MY".
// If there is only a single matching certificate, then its chain will be returned too
func GetMatchingCertsAndChain(certIdentifier CertIdentifier) (store windows.Handle, certCtx *windows.CertContext, certChain []*x509.Certificate, certContainers []CertificateContainer, err error) {
func GetMatchingCertsAndChain(certIdentifier CertIdentifier) (store windows.Handle, certCtxs []*windows.CertContext, certChains [][]*x509.Certificate, certContainers []CertificateContainer, err error) {
storeName, err := windows.UTF16PtrFromString(certIdentifier.SystemStoreName)
if err != nil {
return 0, nil, nil, nil, errors.New("unable to UTF-16 encode personal certificate store name")
Expand All @@ -164,12 +165,14 @@ func GetMatchingCertsAndChain(certIdentifier CertIdentifier) (store windows.Hand
params windows.CertChainFindByIssuerPara
paramsPtr unsafe.Pointer
chainCtx *windows.CertChainContext = nil
certCtx *windows.CertContext
)
params.Size = uint32(unsafe.Sizeof(params))
paramsPtr = unsafe.Pointer(&params)

var curCertCtx *windows.CertContext
var curCert *x509.Certificate
certContainerIndex := 0
for {
// Previous chainCtx should be freed here if it isn't nil
chainCtx, err = windows.CertFindChainInStore(store, encoding, flags, findType, paramsPtr, chainCtx)
Expand Down Expand Up @@ -212,19 +215,26 @@ func GetMatchingCertsAndChain(certIdentifier CertIdentifier) (store windows.Hand

curCert = x509CertChain[0]
if certMatches(certIdentifier, *curCert) {
certContainers = append(certContainers, CertificateContainer{curCert, ""})
certContainers = append(certContainers, CertificateContainer{certContainerIndex, curCert, ""})
certContainerIndex += 1

// Assign to certChain and certCtx at most once in the loop.
// The value is only useful if there is exactly one match in the certificate store.
// When creating a signer, there has to be exactly one matching certificate.
if certChain == nil {
certChain = x509CertChain[:]
certCtx = chainElts[0].CertContext
// This is required later on when creating the WindowsCertStoreSigner
// If this method isn't being called in order to create a WindowsCertStoreSigner,
// this return value will have to be freed explicitly.
windows.CertDuplicateCertificateContext(certCtx)
}
// if certChain == nil {
// certChain = x509CertChain[:]
// certCtx = chainElts[0].CertContext
// // This is required later on when creating the WindowsCertStoreSigner
// // If this method isn't being called in order to create a WindowsCertStoreSigner,
// // this return value will have to be freed explicitly.
// windows.CertDuplicateCertificateContext(certCtx)
// }

certChains = append(certChains, x509CertChain[:])
certCtx = chainElts[0].CertContext
// It's the responsibility of the caller to free the below once they are done using it.
windows.CertDuplicateCertificateContext(certCtx)
certCtxs = append(certCtxs, certCtx)
}

nextIteration:
Expand All @@ -234,14 +244,18 @@ func GetMatchingCertsAndChain(certIdentifier CertIdentifier) (store windows.Hand
log.Printf("found %d matching identities\n", len(certContainers))
}

return store, certCtx, certChain, certContainers, nil
return store, certCtxs, certChains, certContainers, nil

fail:
if chainCtx != nil {
windows.CertFreeCertificateChain(chainCtx)
chainCtx = nil
}
if certCtx != nil {
windows.CertFreeCertificateContext(certCtx)
for i, curCertCtx := range certCtxs {
if curCertCtx != nil {
windows.CertFreeCertificateContext(curCertCtx)
certCtxs[i] = nil
}
}
windows.CertCloseStore(store, 0)

Expand All @@ -250,32 +264,55 @@ fail:

// Gets the certificates that match a CertIdentifier
func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, error) {
store, certCtx, _, certContainers, err := GetMatchingCertsAndChain(certIdentifier)
if certCtx != nil {
windows.CertFreeCertificateContext(certCtx)
store, certCtxs, _, certContainers, err := GetMatchingCertsAndChain(certIdentifier)
for i, curCertCtx := range certCtxs {
if curCertCtx != nil {
windows.CertFreeCertificateContext(curCertCtx)
certCtxs[i] = nil
}
}
windows.CertCloseStore(store, 0)

return certContainers, err
}

// Gets a WindowsCertStoreSigner based on the CertIdentifier
func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAlgorithm string, err error) {
var privateKey *winPrivateKey
store, certCtx, certChain, certContainers, err := GetMatchingCertsAndChain(certIdentifier)
func GetCertStoreSigner(certIdentifier CertIdentifier, useLatestExpiringCert bool) (signer Signer, signingAlgorithm string, err error) {
var (
privateKey *winPrivateKey
selectedCertContainer CertificateContainer
cert *x509.Certificate
certCtx *windows.CertContext
certChain []*x509.Certificate
)

store, certCtxs, certChains, certContainers, err := GetMatchingCertsAndChain(certIdentifier)
if err != nil {
goto fail
}
if len(certContainers) > 1 {
err = errors.New("more than one matching cert found in cert store")
goto fail
}
if len(certContainers) == 0 {
err = errors.New("no matching certs found in cert store")
goto fail
}

signer = &WindowsCertStoreSigner{store: store, cert: certContainers[0].Cert, certCtx: certCtx, certChain: certChain}
if useLatestExpiringCert {
sort.Sort(CertificateContainerList(certContainers))
} else {
if len(certContainers) > 1 {
err = errors.New("multiple matching identities")
goto fail
}
}

selectedCertContainer = certContainers[len(certContainers)-1]
if Debug {
log.Printf("selected certificate: %s", DefaultCertContainerToString(selectedCertContainer))
}
cert = selectedCertContainer.Cert
certCtx = certCtxs[selectedCertContainer.Index]
certChain = certChains[selectedCertContainer.Index]

signer = &WindowsCertStoreSigner{store: store, cert: cert, certCtx: certCtx, certChain: certChain}

privateKey, err = signer.(*WindowsCertStoreSigner).getPrivateKey()
if err != nil {
Expand All @@ -296,8 +333,11 @@ func GetCertStoreSigner(certIdentifier CertIdentifier) (signer Signer, signingAl
return signer, signingAlgorithm, err

fail:
if certCtx != nil {
windows.CertFreeCertificateContext(certCtx)
for i, curCertCtx := range certCtxs {
if curCertCtx != nil {
windows.CertFreeCertificateContext(curCertCtx)
certCtxs[i] = nil
}
}
if signer != nil {
signer.Close()
Expand Down
Loading