Skip to content

Commit

Permalink
feat!: Add update-secret to the connection (#155)
Browse files Browse the repository at this point in the history
One more function had to be added to the ConnectionCallback
  • Loading branch information
mzaniolo authored Feb 9, 2025
1 parent 30381c4 commit d350041
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 1 deletion.
11 changes: 11 additions & 0 deletions amqprs/src/api/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ pub trait ConnectionCallback {

/// Callback to handle connection `unblocked` indication from server
async fn unblocked(&mut self, connection: &Connection);

/// Callback to handle secret updated indication from server
async fn secret_updated(&mut self, connection: &Connection);
}

/// Default type that implements `ConnectionCallback`.
Expand Down Expand Up @@ -87,6 +90,14 @@ impl ConnectionCallback for DefaultConnectionCallback {
connection
);
}

async fn secret_updated(&mut self, connection: &Connection) {
#[cfg(feature = "traces")]
info!(
"handle secret updated notification for connection {}",
connection
);
}
}

/////////////////////////////////////////////////////////////////////////////
Expand Down
29 changes: 28 additions & 1 deletion amqprs/src/api/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ use tokio::sync::{broadcast, mpsc, oneshot};
use crate::{
frame::{
Blocked, Close, CloseOk, Frame, MethodHeader, Open, OpenChannel, OpenChannelOk,
ProtocolHeader, StartOk, TuneOk, Unblocked, DEFAULT_CONN_CHANNEL, FRAME_MIN_SIZE,
ProtocolHeader, StartOk, TuneOk, Unblocked, UpdateSecret, UpdateSecretOk,
DEFAULT_CONN_CHANNEL, FRAME_MIN_SIZE,
},
net::{
ChannelResource, ConnManagementCommand, IncomingMessage, OutgoingMessage, ReaderHandler,
Expand Down Expand Up @@ -1244,6 +1245,32 @@ impl Connection {
let mut shutdown_listener = self.shared.shutdown_subscriber.subscribe();
(shutdown_listener.recv().await).unwrap_or(false)
}

/// Update the secret used by some authentication module such as OAuth2.
///
/// # Errors
///
/// Returns error if fails to send indication to server.
pub async fn update_secret(&self, new_secret: &str, reason: &str) -> Result<()> {
let responder_rx = self
.register_responder(DEFAULT_CONN_CHANNEL, UpdateSecretOk::header())
.await?;

let update_secret = UpdateSecret::new(
new_secret.to_owned().try_into().unwrap(),
reason.to_owned().try_into().unwrap(),
);

synchronous_request!(
self.shared.outgoing_tx,
(DEFAULT_CONN_CHANNEL, update_secret.into_frame()),
responder_rx,
Frame::UpdateSecretOk,
Error::UpdateSecretError
)?;

Ok(())
}
}

impl Drop for DropGuard {
Expand Down
3 changes: 3 additions & 0 deletions amqprs/src/api/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub enum Error {
/// Error in sending or receiving messages via internal communication channel.
/// Usually due to incorrect usage by user.
InternalChannelError(String),
/// Error during updating the secret
UpdateSecretError(String),
}

#[cfg(feature = "urispec")]
Expand Down Expand Up @@ -71,6 +73,7 @@ impl fmt::Display for Error {
Error::InternalChannelError(msg) => {
write!(f, "AMQP internal communication error: {}", msg)
}
Error::UpdateSecretError(msg) => write!(f, "AMQP update secret error: {}", msg),
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions amqprs/src/frame/method/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,11 @@ pub struct UpdateSecret {
pub(crate) reason: ShortStr,
}

impl UpdateSecret {
pub fn new(new_secret: LongStr, reason: ShortStr) -> Self {
Self { new_secret, reason }
}
}

#[derive(Debug, Serialize, Deserialize, Default)]
pub struct UpdateSecretOk;
14 changes: 14 additions & 0 deletions amqprs/src/net/reader_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,20 @@ impl ReaderHandler {
}
Ok(())
}
Frame::UpdateSecretOk(method_header, update_secret_ok) => {
let responder = self
.channel_manager
.remove_responder(&channel_id, method_header)
.expect("responder must be registered");
responder
.send(update_secret_ok.into_frame())
.map_err(|err_frame| {
Error::SyncChannel(format!(
"failed to forward {} to connection {}",
err_frame, self.amqp_connection
))
})
}
// dispatch other frames to channel dispatcher
_ => {
let dispatcher = self.channel_manager.get_dispatcher(&channel_id);
Expand Down
34 changes: 34 additions & 0 deletions amqprs/tests/test_update_secret.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use amqprs::{
callbacks::{DefaultChannelCallback, DefaultConnectionCallback},
connection::Connection,
};
mod common;

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_update_secret() {
common::setup_logging();

// open a connection to RabbitMQ server
let args = common::build_conn_args();

let connection = Connection::open(&args).await.unwrap();
connection
.register_callback(DefaultConnectionCallback)
.await
.unwrap();
// open a channel on the connection
let channel = connection.open_channel(None).await.unwrap();
channel
.register_callback(DefaultChannelCallback)
.await
.unwrap();

connection
.update_secret("123456", "secret expired")
.await
.unwrap();

// close
channel.close().await.unwrap();
connection.close().await.unwrap();
}
1 change: 1 addition & 0 deletions examples/src/callbacks_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ impl ConnectionCallback for ExampleConnectionCallback {

async fn blocked(&mut self, connection: &Connection, reason: String) {}
async fn unblocked(&mut self, connection: &Connection) {}
async fn secret_updated(&mut self, connection: &Connection) {}
}

////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit d350041

Please sign in to comment.