Skip to content

Commit

Permalink
Merge pull request #70 from qaspen-python/feature/require_sslmode_fixes
Browse files Browse the repository at this point in the history
Fixed problem with require sslmode
  • Loading branch information
chandr-andr authored Jul 21, 2024
2 parents 2f3f5c1 + 37feaf9 commit 8bf5393
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 5 deletions.
22 changes: 22 additions & 0 deletions python/tests/test_ssl_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,25 @@ async def test_ssl_mode_require_pool_builder(
pool = builder.build()

await pool.execute("SELECT 1")


async def test_ssl_mode_require_without_ca_file(
postgres_host: str,
postgres_user: str,
postgres_password: str,
postgres_port: int,
postgres_dbname: str,
) -> None:
builder = (
ConnectionPoolBuilder()
.max_pool_size(10)
.host(postgres_host)
.port(postgres_port)
.user(postgres_user)
.password(postgres_password)
.dbname(postgres_dbname)
.ssl_mode(SslMode.Require)
)
pool = builder.build()

await pool.execute("SELECT 1")
2 changes: 1 addition & 1 deletion src/driver/common_options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl TargetSessionAttrs {
}

#[pyclass]
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq)]
pub enum SslMode {
/// Do not use TLS.
Disable,
Expand Down
13 changes: 11 additions & 2 deletions src/driver/connection_pool.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::runtime::tokio_runtime;
use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod};
use openssl::ssl::{SslConnector, SslMethod};
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
use postgres_openssl::MakeTlsConnector;
use pyo3::{pyclass, pyfunction, pymethods, PyAny};
use std::{sync::Arc, vec};
Expand All @@ -13,7 +13,7 @@ use crate::{
};

use super::{
common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs},
common_options::{self, ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs},
connection::Connection,
utils::build_connection_config,
};
Expand Down Expand Up @@ -104,6 +104,15 @@ pub fn connect(
builder.set_ca_file(ca_file)?;
let tls_connector = MakeTlsConnector::new(builder.build());
mgr = Manager::from_config(pg_config, tls_connector, mgr_config);
} else if let Some(ssl_mode) = ssl_mode {
if ssl_mode == common_options::SslMode::Require {
let mut builder = SslConnector::builder(SslMethod::tls())?;
builder.set_verify(SslVerifyMode::NONE);
let tls_connector = MakeTlsConnector::new(builder.build());
mgr = Manager::from_config(pg_config, tls_connector, mgr_config);
} else {
mgr = Manager::from_config(pg_config, NoTls, mgr_config);
}
} else {
mgr = Manager::from_config(pg_config, NoTls, mgr_config);
}
Expand Down
16 changes: 14 additions & 2 deletions src/driver/connection_pool_builder.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
use std::{net::IpAddr, time::Duration};

use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod};
use openssl::ssl::{SslConnector, SslMethod};
use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
use postgres_openssl::MakeTlsConnector;
use pyo3::{pyclass, pymethods, Py, Python};
use tokio_postgres::NoTls;

use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult};

use super::connection_pool::ConnectionPool;
use super::{common_options, connection_pool::ConnectionPool};

#[pyclass]
pub struct ConnectionPoolBuilder {
config: tokio_postgres::Config,
max_db_pool_size: Option<usize>,
conn_recycling_method: Option<RecyclingMethod>,
ca_file: Option<String>,
ssl_mode: Option<common_options::SslMode>,
}

#[pymethods]
Expand All @@ -28,6 +29,7 @@ impl ConnectionPoolBuilder {
max_db_pool_size: Some(2),
conn_recycling_method: None,
ca_file: None,
ssl_mode: None,
}
}

Expand All @@ -53,6 +55,15 @@ impl ConnectionPoolBuilder {
builder.set_ca_file(ca_file)?;
let tls_connector = MakeTlsConnector::new(builder.build());
mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config);
} else if let Some(ssl_mode) = self.ssl_mode {
if ssl_mode == common_options::SslMode::Require {
let mut builder = SslConnector::builder(SslMethod::tls())?;
builder.set_verify(SslVerifyMode::NONE);
let tls_connector = MakeTlsConnector::new(builder.build());
mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config);
} else {
mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config);
}
} else {
mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config);
}
Expand Down Expand Up @@ -167,6 +178,7 @@ impl ConnectionPoolBuilder {
pub fn ssl_mode(self_: Py<Self>, ssl_mode: crate::driver::common_options::SslMode) -> Py<Self> {
Python::with_gil(|gil| {
let mut self_ = self_.borrow_mut(gil);
self_.ssl_mode = Some(ssl_mode);
self_.config.ssl_mode(ssl_mode.to_internal());
});
self_
Expand Down

0 comments on commit 8bf5393

Please sign in to comment.