Skip to content
This repository has been archived by the owner on Jul 6, 2024. It is now read-only.

Commit

Permalink
refactor: Reduce code dedup for auto refresh sequence
Browse files Browse the repository at this point in the history
Add `chain_err` and `map_err` which allows us write sequences of request
with custom actions on errors.
  • Loading branch information
Leander Beernaert committed Nov 17, 2023
1 parent fac21b5 commit 7c8ff20
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 143 deletions.
196 changes: 74 additions & 122 deletions src/clientv2/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,14 @@ use crate::domain::{
TwoFactorAuth, User, UserUid,
};
use crate::http;
use crate::http::{
ClientAsync, ClientRequest, ClientRequestBuilder, ClientSync, FromResponse, Request,
RequestDesc, Sequence, SequenceFromState, X_PM_UID_HEADER,
};
use crate::http::{OwnedRequest, RequestDesc, Sequence, SequenceFromState, X_PM_UID_HEADER};
use crate::requests::{
AuthInfoRequest, AuthInfoResponse, AuthRefreshRequest, AuthRequest, AuthResponse,
GetEventRequest, GetLabelsRequest, GetLatestEventRequest, LogoutRequest, TFAStatus,
TOTPRequest, UserAuth, UserInfoRequest,
};
use go_srp::SRPAuth;
use secrecy::{ExposeSecret, Secret};
use std::future::Future;
#[cfg(not(feature = "async-traits"))]
use std::pin::Pin;
use std::sync::Arc;

#[derive(Debug, thiserror::Error)]
Expand Down Expand Up @@ -87,8 +81,12 @@ impl Session {
SequenceFromState::new(state, login_sequence_1)
}

pub fn submit_totp(&self, code: &str) -> impl Sequence<Output = (), Error = http::Error> {
self.wrap_request(TOTPRequest::new(code).to_request())
pub fn submit_totp<'a>(
&'a self,
code: &'a str,
) -> impl Sequence<Output = (), Error = http::Error> + 'a {
//self.wrap_request(TOTPRequest::new(code).to_request())
self.wrap_request2(TOTPRequest::new(code))
}

pub fn refresh<'a>(
Expand All @@ -103,22 +101,31 @@ impl Session {
})
}

pub fn get_user(&self) -> impl Sequence<Output = User> {
self.wrap_request(UserInfoRequest {}.to_request())
pub fn get_user(&self) -> impl Sequence<Output = User> + '_ {
//self.wrap_request(UserInfoRequest {}.to_request())
// .map(|r| -> Result<User, http::Error> { Ok(r.user) })
self.wrap_request2(UserInfoRequest {})
.map(|r| -> Result<User, http::Error> { Ok(r.user) })
}

pub fn logout(&self) -> impl Sequence<Output = (), Error = http::Error> {
self.wrap_request(LogoutRequest {}.to_request())
pub fn logout(&self) -> impl Sequence<Output = (), Error = http::Error> + '_ {
//self.wrap_request(LogoutRequest {}.to_request())
self.wrap_request2(LogoutRequest {})
}

pub fn get_latest_event(&self) -> impl Sequence<Output = EventId, Error = http::Error> {
self.wrap_request(GetLatestEventRequest {}.to_request())
pub fn get_latest_event(&self) -> impl Sequence<Output = EventId, Error = http::Error> + '_ {
//self.wrap_request(GetLatestEventRequest {}.to_request())
// .map(|r| Ok(r.event_id))
self.wrap_request2(GetLatestEventRequest {})
.map(|r| Ok(r.event_id))
}

pub fn get_event(&self, id: &EventId) -> impl Sequence<Output = Event, Error = http::Error> {
self.wrap_request(GetEventRequest::new(id).to_request())
pub fn get_event<'a, 'b: 'a>(
&'b self,
id: &'a EventId,
) -> impl Sequence<Output = Event, Error = http::Error> + 'a {
//self.wrap_request(GetEventRequest::new(id).to_request())
self.wrap_request2(GetEventRequest::new(id))
}

pub fn get_refresh_data(&self) -> SessionRefreshData {
Expand All @@ -132,14 +139,19 @@ impl Session {
pub fn get_labels(
&self,
label_type: LabelType,
) -> impl Sequence<Output = Vec<Label>, Error = http::Error> {
self.wrap_request(GetLabelsRequest::new(label_type).to_request())
) -> impl Sequence<Output = Vec<Label>, Error = http::Error> + '_ {
//self.wrap_request(GetLabelsRequest::new(label_type).to_request())
// .map(|r| Ok(r.labels))
self.wrap_request2(GetLabelsRequest::new(label_type))
.map(|r| Ok(r.labels))
}

#[inline(always)]
fn wrap_request<R: Request>(&self, r: R) -> SessionRequest<R> {
SessionRequest(r, self.user_auth.clone())
fn wrap_request2<'a, 'b: 'a, R: RequestDesc + 'a>(
&'b self,
r: R,
) -> impl Sequence<Output = R::Output, Error = http::Error> + 'a {
SequenceFromState::new(self, move |s| wrap_session_request(s, r))
}
}

Expand Down Expand Up @@ -176,107 +188,6 @@ fn map_human_verification_err(e: LoginError) -> LoginError {
e
}

pub struct SessionRequest<R: Request>(R, Arc<parking_lot::RwLock<UserAuth>>);

impl<R: Request> SessionRequest<R> {
fn refresh_auth(&self) -> impl Sequence<Output = (), Error = http::Error> + '_ {
let reader = self.1.read();
AuthRefreshRequest::new(
reader.uid.expose_secret(),
reader.refresh_token.expose_secret(),
)
.to_request()
.map(|resp| {
let mut writer = self.1.write();
*writer = UserAuth::from_auth_refresh_response(resp);
Ok(())
})
}

async fn exec_async_impl<'a, C: ClientAsync, F: FromResponse>(
&'a self,
client: &'a C,
) -> Result<F::Output, http::Error> {
let v = self.build(client);
match client.execute_async::<F>(v).await {
Ok(r) => Ok(r),
Err(original_error) => {
if let http::Error::API(api_err) = &original_error {
if api_err.http_code == 401 {
log::debug!("Account session expired, attempting refresh");
// Session expired/not authorized, try auth refresh.
if let Err(e) = self.refresh_auth().do_async(client).await {
log::error!("Failed to refresh account {e}");
return Err(original_error);
}

// Execute request again
return client.execute_async::<F>(self.build(client)).await;
}
}
Err(original_error)
}
}
}
}

impl<R: Request> Request for SessionRequest<R> {
type Response = R::Response;

fn build<C: ClientRequestBuilder>(&self, builder: &C) -> C::Request {
let r = self.0.build(builder);
let borrow = self.1.read();
r.header(X_PM_UID_HEADER, borrow.uid.expose_secret().as_str())
.bearer_token(borrow.access_token.expose_secret())
}

fn exec_sync<T: ClientSync>(
&self,
client: &T,
) -> Result<<Self::Response as FromResponse>::Output, http::Error> {
match client.execute::<Self::Response>(self.build(client)) {
Ok(r) => Ok(r),
Err(original_error) => {
if let http::Error::API(api_err) = &original_error {
if api_err.http_code == 401 {
log::debug!("Account session expired, attempting refresh");
// Session expired/not authorized, try auth refresh.
if let Err(e) = self.refresh_auth().do_sync(client) {
log::error!("Failed to refresh account {e}");
return Err(original_error);
}

// Execute request again
return client.execute::<Self::Response>(self.build(client));
}
}
Err(original_error)
}
}
}

#[cfg(not(feature = "async-traits"))]
fn exec_async<'a, T: ClientAsync>(
&'a self,
client: &'a T,
) -> Pin<
Box<
dyn Future<Output = Result<<Self::Response as FromResponse>::Output, http::Error>> + 'a,
>,
> {
Box::pin(async move { self.exec_async_impl::<T, R::Response>(client).await })
}

#[cfg(feature = "async-traits")]
fn exec_async<'a, T: ClientAsync>(
&'a self,
client: &'a T,
) -> impl Future<Output = Result<<Self::Response as FromResponse>::Output, http::Error>> + 'a
{
async { self.exec_async_impl::<T, R::Response>(client).await }
}
}

struct State<'a> {
username: &'a str,
password: &'a SecretString,
Expand Down Expand Up @@ -336,3 +247,44 @@ fn login_sequence_1(st: State) -> impl Sequence<Output = SessionType, Error = Lo
.map(move |auth_info_response| generate_login_state(st, auth_info_response))
.state(login_sequence_2)
}

fn wrap_session_request<'a, R: RequestDesc + 'a>(
session: &'a Session,
r: R,
) -> impl Sequence<Output = R::Output, Error = http::Error> + 'a {
let data = {
let borrow = session.user_auth.read();
r.build()
.header(X_PM_UID_HEADER, borrow.uid.expose_secret().as_str())
.bearer_token(borrow.access_token.expose_secret())
};

// While we clone headers and url, the body clone is handled efficiently.
OwnedRequest::<R::Response>::new(data.clone()).chain_err(move |e| {
if let http::Error::API(api_err) = &e {
if api_err.http_code == 401 {
log::debug!("Account session expired, attempting refresh");
let borrow = session.user_auth.read();
return Ok(AuthRefreshRequest::new(
borrow.uid.expose_secret(),
borrow.refresh_token.expose_secret(),
)
.to_request()
.chain(move |resp| {
{
let mut writer = session.user_auth.write();
*writer = UserAuth::from_auth_refresh_response(resp);
}
let data = {
let borrow = session.user_auth.read();
data.header(X_PM_UID_HEADER, borrow.uid.expose_secret().as_str())
.bearer_token(borrow.access_token.expose_secret())
};
Ok(OwnedRequest::<R::Response>::new(data))
}));
}
}

Err(e)
})
}
7 changes: 5 additions & 2 deletions src/clientv2/totp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ use crate::http::Sequence;
pub struct TotpSession(pub(super) Session);

impl TotpSession {
pub fn submit_totp(&self, code: &str) -> impl Sequence<Output = Session, Error = http::Error> {
pub fn submit_totp<'a>(
&'a self,
code: &'a str,
) -> impl Sequence<Output = Session, Error = http::Error> + 'a {
let auth = self.0.user_auth.clone();
self.0
.submit_totp(code)
.map(move |_| Ok(Session { user_auth: auth }))
}

pub fn logout(&self) -> impl Sequence<Output = ()> {
pub fn logout(&self) -> impl Sequence<Output = ()> + '_ {
self.0.logout()
}
}
26 changes: 18 additions & 8 deletions src/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::marker::PhantomData;
use std::pin::Pin;

/// HTTP Request representation.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct RequestData {
#[allow(unused)] // Only used by http implementations.
pub(super) method: Method,
Expand All @@ -33,8 +33,8 @@ impl RequestData {
self
}

pub fn bearer_token(self, token: &str) -> Self {
self.header("authorization", format!("Bearer {token}"))
pub fn bearer_token(self, token: impl AsRef<str>) -> Self {
self.header("authorization", format!("Bearer {}", token.as_ref()))
}

pub fn bytes(mut self, bytes: impl Into<Bytes>) -> Self {
Expand All @@ -58,16 +58,26 @@ pub trait RequestDesc {
type Response: FromResponse<Output = Self::Output>;

fn build(&self) -> RequestData;
fn to_request(&self) -> OwnedRequest<Self::Response> {
OwnedRequest(self.build(), PhantomData)
}
}

pub struct OwnedRequest<F: FromResponse>(RequestData, PhantomData<F>);

fn to_request(&self) -> RequestWrapper<Self::Response> {
let data = self.build();
RequestWrapper(data, PhantomData)
impl<F: FromResponse> OwnedRequest<F> {
pub fn new(r: RequestData) -> Self {
Self(r, PhantomData)
}
}

pub struct RequestWrapper<F: FromResponse>(RequestData, PhantomData<F>);
impl<R: RequestDesc> From<R> for OwnedRequest<R::Response> {
fn from(value: R) -> Self {
Self(value.build(), PhantomData)
}
}

impl<F: FromResponse> Request for RequestWrapper<F> {
impl<F: FromResponse> Request for OwnedRequest<F> {
type Response = F;

fn build<C: ClientRequestBuilder>(&self, builder: &C) -> C::Request {
Expand Down
Loading

0 comments on commit 7c8ff20

Please sign in to comment.