diff --git a/python/tests/test_ssl_mode.py b/python/tests/test_ssl_mode.py index ae97900f..99f6b9b6 100644 --- a/python/tests/test_ssl_mode.py +++ b/python/tests/test_ssl_mode.py @@ -74,3 +74,25 @@ async def test_ssl_mode_require_pool_builder( pool = builder.build() await pool.execute("SELECT 1") + + +async def test_ssl_mode_require_without_ca_file( + postgres_host: str, + postgres_user: str, + postgres_password: str, + postgres_port: int, + postgres_dbname: str, +) -> None: + builder = ( + ConnectionPoolBuilder() + .max_pool_size(10) + .host(postgres_host) + .port(postgres_port) + .user(postgres_user) + .password(postgres_password) + .dbname(postgres_dbname) + .ssl_mode(SslMode.Require) + ) + pool = builder.build() + + await pool.execute("SELECT 1") diff --git a/src/driver/common_options.rs b/src/driver/common_options.rs index ee58a161..aaf9329a 100644 --- a/src/driver/common_options.rs +++ b/src/driver/common_options.rs @@ -64,7 +64,7 @@ impl TargetSessionAttrs { } #[pyclass] -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq)] pub enum SslMode { /// Do not use TLS. Disable, diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 5d5de42d..303fffc4 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,6 +1,6 @@ use crate::runtime::tokio_runtime; use deadpool_postgres::{Manager, ManagerConfig, Object, Pool, RecyclingMethod}; -use openssl::ssl::{SslConnector, SslMethod}; +use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; use postgres_openssl::MakeTlsConnector; use pyo3::{pyclass, pyfunction, pymethods, PyAny}; use std::{sync::Arc, vec}; @@ -13,7 +13,7 @@ use crate::{ }; use super::{ - common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, + common_options::{self, ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, connection::Connection, utils::build_connection_config, }; @@ -104,6 +104,15 @@ pub fn connect( builder.set_ca_file(ca_file)?; let tls_connector = MakeTlsConnector::new(builder.build()); mgr = Manager::from_config(pg_config, tls_connector, mgr_config); + } else if let Some(ssl_mode) = ssl_mode { + if ssl_mode == common_options::SslMode::Require { + let mut builder = SslConnector::builder(SslMethod::tls())?; + builder.set_verify(SslVerifyMode::NONE); + let tls_connector = MakeTlsConnector::new(builder.build()); + mgr = Manager::from_config(pg_config, tls_connector, mgr_config); + } else { + mgr = Manager::from_config(pg_config, NoTls, mgr_config); + } } else { mgr = Manager::from_config(pg_config, NoTls, mgr_config); } diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index edc6dd7e..bb2047b4 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -1,14 +1,14 @@ use std::{net::IpAddr, time::Duration}; use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; -use openssl::ssl::{SslConnector, SslMethod}; +use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; use postgres_openssl::MakeTlsConnector; use pyo3::{pyclass, pymethods, Py, Python}; use tokio_postgres::NoTls; use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; -use super::connection_pool::ConnectionPool; +use super::{common_options, connection_pool::ConnectionPool}; #[pyclass] pub struct ConnectionPoolBuilder { @@ -16,6 +16,7 @@ pub struct ConnectionPoolBuilder { max_db_pool_size: Option, conn_recycling_method: Option, ca_file: Option, + ssl_mode: Option, } #[pymethods] @@ -28,6 +29,7 @@ impl ConnectionPoolBuilder { max_db_pool_size: Some(2), conn_recycling_method: None, ca_file: None, + ssl_mode: None, } } @@ -53,6 +55,15 @@ impl ConnectionPoolBuilder { builder.set_ca_file(ca_file)?; let tls_connector = MakeTlsConnector::new(builder.build()); mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config); + } else if let Some(ssl_mode) = self.ssl_mode { + if ssl_mode == common_options::SslMode::Require { + let mut builder = SslConnector::builder(SslMethod::tls())?; + builder.set_verify(SslVerifyMode::NONE); + let tls_connector = MakeTlsConnector::new(builder.build()); + mgr = Manager::from_config(self.config.clone(), tls_connector, mgr_config); + } else { + mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config); + } } else { mgr = Manager::from_config(self.config.clone(), NoTls, mgr_config); } @@ -167,6 +178,7 @@ impl ConnectionPoolBuilder { pub fn ssl_mode(self_: Py, ssl_mode: crate::driver::common_options::SslMode) -> Py { Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); + self_.ssl_mode = Some(ssl_mode); self_.config.ssl_mode(ssl_mode.to_internal()); }); self_