diff --git a/storage/src/backend/registry.rs b/storage/src/backend/registry.rs index 737a03454ff..a07090c9f42 100644 --- a/storage/src/backend/registry.rs +++ b/storage/src/backend/registry.rs @@ -93,19 +93,22 @@ impl Cache { } #[derive(Default)] -struct HashCache(RwLock>); +struct HashCache(RwLock>); -impl HashCache { +impl HashCache { fn new() -> Self { HashCache(RwLock::new(HashMap::new())) } - fn get(&self, key: &str) -> Option { + fn get(&self, key: &str) -> Option + where + T: Clone, + { let cached_guard = self.0.read().unwrap(); cached_guard.get(key).cloned() } - fn set(&self, key: String, value: String) { + fn set(&self, key: String, value: T) { let mut cached_guard = self.0.write().unwrap(); cached_guard.insert(key, value); } @@ -136,6 +139,7 @@ struct BasicAuth { } #[derive(Debug, Clone)] +#[allow(dead_code)] struct BearerAuth { realm: String, service: String, @@ -189,10 +193,14 @@ struct RegistryState { // Example: RwLock<"Bearer "> // RwLock<"Basic base64()"> cached_auth: Cache, + // Cache for the HTTP method when getting auth, it is "true" when using "GET" method. + // Due to the different implementations of various image registries, auth requests + // may use the GET or POST methods, we need to cache the method after the + // fallback, so it can be reused next time and reduce an unnecessary request. + cached_auth_using_http_get: HashCache, // Cache 30X redirect url // Example: RwLock", "">> - cached_redirect: HashCache, - + cached_redirect: HashCache, // The epoch timestamp of token expiration, which is obtained from the registry server. token_expired_at: ArcSwapOption, // Cache bearer auth for refreshing token. @@ -238,12 +246,85 @@ impl RegistryState { } } - /// Request registry authentication server to get bearer token + // Request registry authentication server to get bearer token fn get_token(&self, auth: BearerAuth, connection: &Arc) -> Result { - // The information needed for getting token needs to be placed both in - // the query and in the body to be compatible with different registry - // implementations, which have been tested on these platforms: - // docker hub, harbor, github ghcr, aliyun acr. + let http_get = self + .cached_auth_using_http_get + .get(&self.host) + .unwrap_or_default(); + let resp = if http_get { + self.get_token_with_get(&auth, connection)? + } else { + match self.get_token_with_post(&auth, connection) { + Ok(resp) => resp, + Err(err) => { + warn!("retry http GET method to get auth token: {}", err); + let resp = self.get_token_with_get(&auth, connection)?; + // Cache http method for next use. + self.cached_auth_using_http_get.set(self.host.clone(), true); + resp + } + } + }; + + let ret: TokenResponse = resp.json().map_err(|e| { + einval!(format!( + "registry auth server response decode failed: {:?}", + e + )) + })?; + + if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) { + self.token_expired_at + .store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in))); + debug!( + "cached bearer auth, next time: {}", + now_timestamp.as_secs() + ret.expires_in + ); + } + + // Cache bearer auth for refreshing token. + self.cached_bearer_auth.store(Some(Arc::new(auth))); + + Ok(ret) + } + + // Get bearer token using a POST request + fn get_token_with_post( + &self, + auth: &BearerAuth, + connection: &Arc, + ) -> Result { + let mut form = HashMap::new(); + form.insert("service".to_string(), auth.service.clone()); + form.insert("scope".to_string(), auth.scope.clone()); + form.insert("grant_type".to_string(), "password".to_string()); + form.insert("username".to_string(), self.username.clone()); + form.insert("passward".to_string(), self.password.clone()); + form.insert("client_id".to_string(), REGISTRY_CLIENT_ID.to_string()); + + let mut headers = HeaderMap::new(); + + let token_resp = connection + .call::<&[u8]>( + Method::POST, + auth.realm.as_str(), + None, + Some(ReqBody::Form(form)), + &mut headers, + true, + ) + .map_err(|e| einval!(format!("registry auth server request failed {:?}", e)))?; + + Ok(token_resp) + } + + // Get bearer token using a GET request + fn get_token_with_get( + &self, + auth: &BearerAuth, + connection: &Arc, + ) -> Result { let query = [ ("service", auth.service.as_str()), ("scope", auth.scope.as_str()), @@ -253,45 +334,20 @@ impl RegistryState { ("client_id", REGISTRY_CLIENT_ID), ]; - let mut form = HashMap::new(); - for (k, v) in &query { - form.insert(k.to_string(), v.to_string()); - } - let mut headers = HeaderMap::new(); - if let Some(auth_header) = &auth.header { - headers.insert(HEADER_AUTHORIZATION, auth_header.clone()); - } let token_resp = connection .call::<&[u8]>( Method::GET, auth.realm.as_str(), Some(&query), - Some(ReqBody::Form(form)), + None, &mut headers, true, ) .map_err(|e| einval!(format!("registry auth server request failed {:?}", e)))?; - let ret: TokenResponse = token_resp.json().map_err(|e| { - einval!(format!( - "registry auth server response decode failed: {:?}", - e - )) - })?; - if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) { - self.token_expired_at - .store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in))); - debug!( - "cached bearer auth, next time: {}", - now_timestamp.as_secs() + ret.expires_in - ); - } - // Cache bearer auth for refreshing token. - self.cached_bearer_auth.store(Some(Arc::new(auth))); - - Ok(ret) + Ok(token_resp) } fn get_auth_header(&self, auth: Auth, connection: &Arc) -> Result { @@ -809,6 +865,7 @@ impl Registry { retry_limit, blob_url_scheme: config.blob_url_scheme.clone(), blob_redirected_host: config.blob_redirected_host.clone(), + cached_auth_using_http_get: HashCache::new(), cached_redirect: HashCache::new(), token_expired_at: ArcSwapOption::new(None), cached_bearer_auth: ArcSwapOption::new(None), @@ -990,6 +1047,7 @@ mod tests { retry_limit: 5, blob_url_scheme: "https".to_string(), blob_redirected_host: "oss.alibaba-inc.com".to_string(), + cached_auth_using_http_get: Default::default(), cached_auth: Default::default(), cached_redirect: Default::default(), token_expired_at: ArcSwapOption::new(None),