Skip to content

Commit

Permalink
fix the auth method
Browse files Browse the repository at this point in the history
  • Loading branch information
lihuahua123 authored Sep 14, 2023
1 parent d2fcfcd commit 1c02435
Showing 1 changed file with 83 additions and 27 deletions.
110 changes: 83 additions & 27 deletions storage/src/backend/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ struct BasicAuth {
}

#[derive(Debug, Clone)]
#[allow(dead_code)]
struct BearerAuth {
realm: String,
service: String,
Expand Down Expand Up @@ -188,6 +189,10 @@ struct RegistryState {
// Use RwLock here to avoid using mut backend trait object.
// Example: RwLock<"Bearer <token>">
// RwLock<"Basic base64(<username:password>)">

// 获取授权的方式,get | post
cached_get_auth_method: HashCache,

cached_auth: Cache,
// Cache 30X redirect url
// Example: RwLock<HashMap<"<blob_id>", "<redirected_url>">>
Expand Down Expand Up @@ -238,12 +243,51 @@ impl RegistryState {
}
}

/// Request registry authentication server to get bearer token
// v6
fn get_token(&self, auth: BearerAuth, connection: &Arc<Connection>) -> Result<TokenResponse> {
// The information needed for getting token needs to be placed both in
// the query and in the body to be compatible with different registry
// implementations, which have been tested on these platforms:
// docker hub, harbor, github ghcr, aliyun acr.
let mut resp_token: Response;
if let Some(method) = self.cached_get_auth_method.get(&self.host) {
if method == "post" {
resp_token = self.get_token_with_oauth(&auth, connection)?;
} else {
resp_token = self.get_token_without_oauth(&auth, connection)?;
}
} else {
resp_token = self.get_token_without_oauth(&auth, connection)?;
if !resp_token.status().is_success() {
resp_token = self.get_token_with_oauth(&auth, connection)?;
self.cached_get_auth_method
.set(self.host.clone(), "post".to_string());
}
}

let ret: TokenResponse = resp_token.json().map_err(|e| {
einval!(format!(
"registry auth server response decode failed: {:?}",
e
))
})?;

if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) {
self.token_expired_at
.store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in)));
debug!(
"cached bearer auth, next time: {}",
now_timestamp.as_secs() + ret.expires_in
);
}

// Cache bearer auth for refreshing token.
self.cached_bearer_auth.store(Some(Arc::new(auth)));

Ok(ret)
}

fn get_token_with_oauth(
&self,
auth: &BearerAuth,
connection: &Arc<Connection>,
) -> Result<Response> {
let query = [
("service", auth.service.as_str()),
("scope", auth.scope.as_str()),
Expand All @@ -259,39 +303,49 @@ impl RegistryState {
}

let mut headers = HeaderMap::new();
if let Some(auth_header) = &auth.header {
headers.insert(HEADER_AUTHORIZATION, auth_header.clone());
}

let token_resp = connection
.call::<&[u8]>(
Method::POST,
auth.realm.as_str(),
None, //应该是不需要
Some(ReqBody::Form(form)),
&mut headers, //同上,应当不需要
true,
)
.map_err(|e| einval!(format!("registry auth server request failed {:?}", e)))?;

Ok(token_resp)
}

fn get_token_without_oauth(
&self,
auth: &BearerAuth,
connection: &Arc<Connection>,
) -> Result<Response> {
let query = [
("service", auth.service.as_str()),
("scope", auth.scope.as_str()),
("grant_type", "password"),
("username", self.username.as_str()),
("password", self.password.as_str()),
("client_id", REGISTRY_CLIENT_ID),
];

let mut headers = HeaderMap::new();

let token_resp = connection
.call::<&[u8]>(
Method::GET,
auth.realm.as_str(),
Some(&query),
Some(ReqBody::Form(form)),
None, //应当为空
&mut headers,
true,
)
.map_err(|e| einval!(format!("registry auth server request failed {:?}", e)))?;
let ret: TokenResponse = token_resp.json().map_err(|e| {
einval!(format!(
"registry auth server response decode failed: {:?}",
e
))
})?;
if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) {
self.token_expired_at
.store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in)));
debug!(
"cached bearer auth, next time: {}",
now_timestamp.as_secs() + ret.expires_in
);
}

// Cache bearer auth for refreshing token.
self.cached_bearer_auth.store(Some(Arc::new(auth)));

Ok(ret)
Ok(token_resp)
}

fn get_auth_header(&self, auth: Auth, connection: &Arc<Connection>) -> Result<String> {
Expand Down Expand Up @@ -809,6 +863,7 @@ impl Registry {
retry_limit,
blob_url_scheme: config.blob_url_scheme.clone(),
blob_redirected_host: config.blob_redirected_host.clone(),
cached_get_auth_method: HashCache::new(),
cached_redirect: HashCache::new(),
token_expired_at: ArcSwapOption::new(None),
cached_bearer_auth: ArcSwapOption::new(None),
Expand Down Expand Up @@ -990,6 +1045,7 @@ mod tests {
retry_limit: 5,
blob_url_scheme: "https".to_string(),
blob_redirected_host: "oss.alibaba-inc.com".to_string(),
cached_get_auth_method: Default::default(),
cached_auth: Default::default(),
cached_redirect: Default::default(),
token_expired_at: ArcSwapOption::new(None),
Expand Down

0 comments on commit 1c02435

Please sign in to comment.