diff --git a/aws/credentials.go b/aws/credentials.go index eb42ea8debf..999807db8da 100644 --- a/aws/credentials.go +++ b/aws/credentials.go @@ -90,7 +90,9 @@ type Credentials struct { // Expired returns if the credentials have expired. func (v Credentials) Expired() bool { if v.CanExpire { - return !v.Expires.After(sdk.NowTime()) + // Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry + // time is always based on reported wall-clock time. + return !v.Expires.After(sdk.NowTime().Round(0)) } return false diff --git a/ec2imds/token_provider.go b/ec2imds/token_provider.go index e8755f9525e..dea621a9e0b 100644 --- a/ec2imds/token_provider.go +++ b/ec2imds/token_provider.go @@ -48,7 +48,9 @@ var timeNow = time.Now // Expired returns if the token is expired. func (t *apiToken) Expired() bool { - return timeNow().After(t.expires) + // Calling Round(0) on the current time will truncate the monotonic reading only. Ensures credential expiry + // time is always based on reported wall-clock time. + return timeNow().Round(0).After(t.expires) } func (t *tokenProvider) ID() string { return "APITokenProvider" } @@ -112,7 +114,7 @@ func (t *tokenProvider) HandleDeserialize( return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse) } - if resp.StatusCode == 401 { // unauthorized + if resp.StatusCode == http.StatusUnauthorized { // unauthorized err = &retryableError{Err: err} t.enable() } @@ -228,5 +230,8 @@ func (t *tokenProvider) disable() { // enable enables the token provide to start refreshing tokens, and adding them // to the pending request. func (t *tokenProvider) enable() { + t.tokenMux.Lock() + t.token = nil + t.tokenMux.Unlock() atomic.StoreUint32(&t.disabled, 0) }