Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to Go v2 SDK #99

Merged
merged 1 commit into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Migrate to Go v2 SDK
* Pull in files related to Go v2 SDK for Roles Anywhere
* Downstream changes to code calling CreateSession with new client
  • Loading branch information
13ajay committed Feb 19, 2025
commit 24d4dae0d4efcefea7b13b20fd151064f3455d7f
80 changes: 60 additions & 20 deletions aws_signing_helper/credentials.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
package aws_signing_helper

import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"log"
"net/http"
"runtime"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/rolesanywhere-credential-helper/rolesanywhere"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)

type CredentialsOpts struct {
Expand All @@ -38,6 +41,18 @@ type CredentialsOpts struct {
RoleSessionName string
}

// Middleware to set a custom user agent header
func createCredHelperUserAgentMiddleware(userAgent string) middleware.BuildMiddleware {
return middleware.BuildMiddlewareFunc("UserAgent", func(
ctx context.Context, input middleware.BuildInput, next middleware.BuildHandler,
) (middleware.BuildOutput, middleware.Metadata, error) {
if req, ok := input.Request.(*smithyhttp.Request); ok {
req.Header.Set("User-Agent", userAgent)
}
return next.HandleBuild(ctx, input)
})
}

// Function to create session and generate credentials
func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (CredentialProcessOutput, error) {
// Assign values to region and endpoint if they haven't already been assigned
Expand All @@ -58,15 +73,12 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
opts.Region = trustAnchorArn.Region
}

mySession := session.Must(session.NewSession())

var logLevel aws.LogLevelType
var logMode aws.ClientLogMode = 0
if Debug {
logLevel = aws.LogDebug
} else {
logLevel = aws.LogOff
logMode = aws.LogSigning | aws.LogRetries | aws.LogRequestWithBody | aws.LogResponseWithBody | aws.LogRequestEventMessage | aws.LogResponseEventMessage
}

// Custom HTTP client with proxy and TLS settings
var tr *http.Transport
if opts.WithProxy {
tr = &http.Transport{
Expand All @@ -78,15 +90,26 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: opts.NoVerifySSL},
}
}
client := &http.Client{Transport: tr}
config := aws.NewConfig().WithRegion(opts.Region).WithHTTPClient(client).WithLogLevel(logLevel)
httpClient := &http.Client{Transport: tr}
ctx := context.TODO()
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(opts.Region), config.WithHTTPClient(httpClient), config.WithClientLogMode(logMode))
if err != nil {
return CredentialProcessOutput{}, err
}

// Override endpoint if specified
if opts.Endpoint != "" {
config.WithEndpoint(opts.Endpoint)
cfg.BaseEndpoint = aws.String(opts.Endpoint)
}
rolesAnywhereClient := rolesanywhere.New(mySession, config)
rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler")
rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: request.MakeAddToUserAgentHandler("CredHelper", opts.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)})
rolesAnywhereClient.Handlers.Sign.Clear()

// Set a custom user agent
userAgentStr := fmt.Sprintf("CredHelper/%s (%s; %s; %s)", opts.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error {
stack.Build.Remove("UserAgent")
return stack.Build.Add(createCredHelperUserAgentMiddleware(userAgentStr), middleware.After)
})

// Add custom request signer, implementing SigV4-X509
certificate, err := signer.Certificate()
if err != nil {
return CredentialProcessOutput{}, errors.New("unable to find certificate")
Expand All @@ -98,10 +121,27 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
log.Println(err)
}
}
rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: CreateRequestSignFunction(signer, signatureAlgorithm, certificate, certificateChain)})
cfg.APIOptions = append(cfg.APIOptions, func(stack *middleware.Stack) error {
for _, name := range stack.Finalize.List() {
fmt.Println(name)
}
// Remove middleware related to SigV4 signing
stack.Finalize.Remove("Signing")
stack.Finalize.Remove("setLegacyContextSigningOptions")
stack.Finalize.Remove("GetIdentity")
// Add middleware for SigV4-X509 signing
stack.Finalize.Add(middleware.FinalizeMiddlewareFunc("Signing", CreateRequestSignFinalizeFunction(signer, opts.Region, signatureAlgorithm, certificate, certificateChain)), middleware.After)
for _, name := range stack.Finalize.List() {
fmt.Println(name)
}
return nil
})

// Create the Roles Anywhere client using the above-constructed Config
rolesAnywhereClient := rolesanywhere.NewFromConfig(cfg)

certificateStr := base64.StdEncoding.EncodeToString(certificate.Raw)
durationSeconds := int64(opts.SessionDuration)
durationSeconds := int32(opts.SessionDuration)
createSessionRequest := rolesanywhere.CreateSessionInput{
Cert: &certificateStr,
ProfileArn: &opts.ProfileArnStr,
Expand All @@ -114,7 +154,7 @@ func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorith
if opts.RoleSessionName != "" {
createSessionRequest.RoleSessionName = &opts.RoleSessionName
}
output, err := rolesAnywhereClient.CreateSession(&createSessionRequest)
output, err := rolesAnywhereClient.CreateSession(ctx, &createSessionRequest)
if err != nil {
return CredentialProcessOutput{}, err
}
Expand Down
96 changes: 36 additions & 60 deletions aws_signing_helper/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package aws_signing_helper

import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/rand"
Expand All @@ -24,8 +25,9 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"golang.org/x/crypto/pkcs12"
"golang.org/x/term"
)
Expand Down Expand Up @@ -57,6 +59,9 @@ var (
"Trust",
"CA",
}

// Signing name for the IAM Roles Anywhere service
ROLESANYWHERE_SIGNING_NAME = "rolesanywhere"
)

// Interface that all signers will have to implement
Expand Down Expand Up @@ -425,75 +430,46 @@ func certificateChainToString(certificateChain []*x509.Certificate) string {
return x509ChainString.String()
}

func CreateRequestSignFunction(signer crypto.Signer, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate) func(*request.Request) {
return func(req *request.Request) {
region := req.ClientInfo.SigningRegion
if region == "" {
region = aws.StringValue(req.Config.Region)
}

name := req.ClientInfo.SigningName
if name == "" {
name = req.ClientInfo.ServiceName
}

signerParams := SignerParams{time.Now(), region, name, signingAlgorithm}

// Set headers that are necessary for signing
req.HTTPRequest.Header.Set(host, req.HTTPRequest.URL.Host)
req.HTTPRequest.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime())
req.HTTPRequest.Header.Set(x_amz_x509, certificateToString(certificate))
if certificateChain != nil {
req.HTTPRequest.Header.Set(x_amz_x509_chain, certificateChainToString(certificateChain))
}

contentSha256 := calculateContentHash(req.HTTPRequest, req.Body)
if req.HTTPRequest.Header.Get(x_amz_content_sha256) == "required" {
req.HTTPRequest.Header.Set(x_amz_content_sha256, contentSha256)
func CreateRequestSignFinalizeFunction(signer crypto.Signer, signingRegion string, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate) func(context.Context, middleware.FinalizeInput, middleware.FinalizeHandler) (middleware.FinalizeOutput, middleware.Metadata, error) {
return func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (out middleware.FinalizeOutput, metadata middleware.Metadata, err error) {
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, errors.New(fmt.Sprintf("unexpected request middleware type %T", in.Request))
}

canonicalRequest, signedHeadersString := createCanonicalRequest(req.HTTPRequest, req.Body, contentSha256)
payloadHash := v4.GetPayloadHash(ctx)
signRequest(signer, signingRegion, signingAlgorithm, certificate, certificateChain, req.Request, payloadHash)

stringToSign := CreateStringToSign(canonicalRequest, signerParams)
signatureBytes, err := signer.Sign(rand.Reader, []byte(stringToSign), crypto.SHA256)
if err != nil {
log.Println("could not sign", err)
os.Exit(1)
}
signature := hex.EncodeToString(signatureBytes)

req.HTTPRequest.Header.Set(authorization, BuildAuthorizationHeader(req.HTTPRequest, req.Body, signedHeadersString, signature, certificate, signerParams))
req.SignedHeaderVals = req.HTTPRequest.Header
return next.HandleFinalize(ctx, in)
}
}

// Find the SHA256 hash of the provided request body as a io.ReadSeeker
func makeSha256Reader(reader io.ReadSeeker) []byte {
hash := sha256.New()
start, _ := reader.Seek(0, 1)
defer reader.Seek(start, 0)
func signRequest(signer crypto.Signer, signingRegion string, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate, req *http.Request, payloadHash string) {
signerParams := SignerParams{time.Now(), signingRegion, ROLESANYWHERE_SIGNING_NAME, signingAlgorithm}

io.Copy(hash, reader)
return hash.Sum(nil)
}
// Set headers that are necessary for signing
req.Header.Set(host, req.URL.Host)
req.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime())
req.Header.Set(x_amz_x509, certificateToString(certificate))
if certificateChain != nil {
req.Header.Set(x_amz_x509_chain, certificateChainToString(certificateChain))
}

// Calculate the hash of the request body
func calculateContentHash(r *http.Request, body io.ReadSeeker) string {
hash := r.Header.Get(x_amz_content_sha256)
canonicalRequest, signedHeadersString := createCanonicalRequest(req, payloadHash)

if hash == "" {
if body == nil {
hash = emptyStringSHA256
} else {
hash = hex.EncodeToString(makeSha256Reader(body))
}
stringToSign := CreateStringToSign(canonicalRequest, signerParams)
signatureBytes, err := signer.Sign(rand.Reader, []byte(stringToSign), crypto.SHA256)
if err != nil {
log.Println("could not sign request", err)
os.Exit(1)
}
signature := hex.EncodeToString(signatureBytes)

return hash
req.Header.Set(authorization, BuildAuthorizationHeader(req, signedHeadersString, signature, certificate, signerParams))
}

// Create the canonical query string.
func createCanonicalQueryString(r *http.Request, body io.ReadSeeker) string {
func createCanonicalQueryString(r *http.Request) string {
rawQuery := strings.Replace(r.URL.Query().Encode(), "+", "%20", -1)
return rawQuery
}
Expand Down Expand Up @@ -573,14 +549,14 @@ func stripExcessSpaces(vals []string) {
}

// Create the canonical request.
func createCanonicalRequest(r *http.Request, body io.ReadSeeker, contentSha256 string) (string, string) {
func createCanonicalRequest(r *http.Request, contentSha256 string) (string, string) {
var canonicalRequestStrBuilder strings.Builder
canonicalHeaderString, signedHeadersString := createCanonicalHeaderString(r)
canonicalRequestStrBuilder.WriteString("POST")
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString("/sessions")
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString(createCanonicalQueryString(r, body))
canonicalRequestStrBuilder.WriteString(createCanonicalQueryString(r))
canonicalRequestStrBuilder.WriteString("\n")
canonicalRequestStrBuilder.WriteString(canonicalHeaderString)
canonicalRequestStrBuilder.WriteString("\n\n")
Expand All @@ -607,7 +583,7 @@ func CreateStringToSign(canonicalRequest string, signerParams SignerParams) stri
}

// Builds the complete authorization header
func BuildAuthorizationHeader(request *http.Request, body io.ReadSeeker, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string {
func BuildAuthorizationHeader(request *http.Request, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string {
signingCredentials := certificate.SerialNumber.String() + "/" + signerParams.GetScope()
credential := "Credential=" + signingCredentials
signerHeaders := "SignedHeaders=" + signedHeadersString
Expand Down
11 changes: 4 additions & 7 deletions aws_signing_helper/signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"testing"
"time"
"unicode/utf8"

"github.com/aws/aws-sdk-go/aws/request"
)

const TestCredentialsFilePath = "/tmp/credentials"
Expand Down Expand Up @@ -125,7 +123,6 @@ func TestBuildAuthorizationHeader(t *testing.T) {
certificate1 := certificateList1[0]
pkPath := "../tst/certs/rsa-2048-key.pem"

awsRequest := request.Request{HTTPRequest: testRequest}
signer, signingAlgorithm, err := GetFileSystemSigner(pkPath, "", path, false)
if err != nil {
t.Log(err)
Expand All @@ -151,8 +148,9 @@ func TestBuildAuthorizationHeader(t *testing.T) {
t.Fail()
}
}
requestSignFunction := CreateRequestSignFunction(signer, signingAlgorithm, certificate, certificateChain)
requestSignFunction(&awsRequest)
signingRegion := "us-west-2"
emptyStringSHA256 := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
signRequest(signer, signingRegion, signingAlgorithm, certificate, certificateChain, testRequest, emptyStringSHA256)

certificateList2, _ := ReadCertificateBundleData("../tst/certs/rsa-4096-sha256-cert.pem")
certificate2 := certificateList2[0]
Expand Down Expand Up @@ -180,8 +178,7 @@ func TestBuildAuthorizationHeader(t *testing.T) {
}
os.Rename("../tst/certs/rsa-2048-sha256-cert.pem", "../tst/certs/rsa-4096-sha256-cert.pem")
os.Rename("../tst/certs/rsa-2048-sha256-cert.pem.bak", "../tst/certs/rsa-2048-sha256-cert.pem")
requestSignFunction2 := CreateRequestSignFunction(signer, signingAlgorithm, certificate, certificateChain)
requestSignFunction2(&awsRequest)
signRequest(signer, signingRegion, signingAlgorithm, certificate, certificateChain, testRequest, emptyStringSHA256)
}

func TestSign(t *testing.T) {
Expand Down
15 changes: 13 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ go 1.22.5

require (
github.com/aws/aws-sdk-go v1.55.5
github.com/aws/aws-sdk-go-v2 v1.36.1
github.com/aws/aws-sdk-go-v2/config v1.29.6
github.com/aws/smithy-go v1.22.2
github.com/google/go-tpm v0.3.3
github.com/miekg/pkcs11 v1.1.1
github.com/spf13/cobra v1.8.1
Expand All @@ -15,8 +18,16 @@ require (
)

require (
github.com/aws/aws-sdk-go-v2/credentials v1.17.59 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.28 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.32 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.32 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.13 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.15 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.14 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.14 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
Loading