Skip to content

Commit

Permalink
feat: support vec/array impl for DnsResolver (#368)
Browse files Browse the repository at this point in the history
Closes #332

Co-authored-by: parkma99 <park-ma@hotmail.com>
Co-authored-by: Glen De Cauwsemaecker <contact@glendc.com>
  • Loading branch information
3 people authored Dec 28, 2024
1 parent a610feb commit b890732
Show file tree
Hide file tree
Showing 4 changed files with 305 additions and 39 deletions.
169 changes: 169 additions & 0 deletions rama-dns/src/chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
use std::error::Error;
use std::net::{Ipv4Addr, Ipv6Addr};

use rama_net::address::Domain;

use crate::DnsResolver;

macro_rules! dns_resolver_chain_impl {
() => {
async fn ipv4_lookup(&self, domain: Domain) -> Result<Vec<Ipv4Addr>, Self::Error> {
let mut errors = Vec::new();
for resolver in self {
match resolver.ipv4_lookup(domain.clone()).await {
Ok(ipv4s) => return Ok(ipv4s),
Err(err) => errors.push(err.into()),
}
}
Err(errors)
}

async fn ipv6_lookup(&self, domain: Domain) -> Result<Vec<Ipv6Addr>, Self::Error> {
let mut errors = Vec::new();
for resolver in self {
match resolver.ipv6_lookup(domain.clone()).await {
Ok(ipv6s) => return Ok(ipv6s),
Err(err) => errors.push(err.into()),
}
}
Err(errors)
}
};
}

impl<R> DnsResolver for Vec<R>
where
R: DnsResolver + Send,
R::Error: Into<Box<dyn Error + Send + Sync>>,
{
type Error = Vec<Box<dyn Error + Send + Sync>>;

dns_resolver_chain_impl!();
}

impl<R, const N: usize> DnsResolver for [R; N]
where
R: DnsResolver + Send,
R::Error: Into<Box<dyn Error + Send + Sync>>,
{
type Error = Vec<Box<dyn Error + Send + Sync>>;

dns_resolver_chain_impl!();
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{DenyAllDns, InMemoryDns};
use rama_core::combinators::Either;
use std::net::{Ipv4Addr, Ipv6Addr};

#[tokio::test]
async fn test_empty_chain_vec() {
let v = Vec::<InMemoryDns>::new();
assert!(v
.ipv4_lookup(Domain::from_static("plabayo.tech"))
.await
.is_err());
assert!(v
.ipv6_lookup(Domain::from_static("plabayo.tech"))
.await
.is_err());
}

#[tokio::test]
async fn test_empty_chain_array() {
let a: [InMemoryDns; 0] = [];
assert!(a
.ipv4_lookup(Domain::from_static("plabayo.tech"))
.await
.is_err());
assert!(a
.ipv6_lookup(Domain::from_static("plabayo.tech"))
.await
.is_err());
}

#[tokio::test]
async fn test_chain_ok_err_ipv4() {
let mut dns = InMemoryDns::new();
dns.insert_addr(
Domain::from_static("example.com"),
Ipv4Addr::new(127, 0, 0, 1),
);
let v = vec![Either::A(dns), Either::B(DenyAllDns::new())];

let result = v
.ipv4_lookup(Domain::from_static("example.com"))
.await
.unwrap();
assert_eq!(result[0], Ipv4Addr::new(127, 0, 0, 1));
}

#[tokio::test]
async fn test_chain_err_ok_ipv6() {
let mut dns = InMemoryDns::new();
dns.insert_addr(
Domain::from_static("example.com"),
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
);
let v = vec![Either::B(DenyAllDns::new()), Either::A(dns)];

let result = v
.ipv6_lookup(Domain::from_static("example.com"))
.await
.unwrap();
assert_eq!(result[0], Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
}

#[tokio::test]
async fn test_chain_ok_ok_ipv6() {
let mut dns1 = InMemoryDns::new();
let mut dns2 = InMemoryDns::new();
dns1.insert_addr(
Domain::from_static("example.com"),
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
);
dns2.insert_addr(
Domain::from_static("example.com"),
Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2),
);

let v = vec![dns1, dns2];
let result = v
.ipv6_lookup(Domain::from_static("example.com"))
.await
.unwrap();
// Should return the first successful result
assert_eq!(result[0], Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
}

#[tokio::test]
async fn test_chain_err_err_ok_ipv4() {
let mut dns = InMemoryDns::new();
dns.insert_addr(
Domain::from_static("example.com"),
Ipv4Addr::new(127, 0, 0, 1),
);

let v = vec![
Either::B(DenyAllDns::new()),
Either::B(DenyAllDns::new()),
Either::A(dns),
];
let result = v
.ipv4_lookup(Domain::from_static("example.com"))
.await
.unwrap();
assert_eq!(result[0], Ipv4Addr::new(127, 0, 0, 1));
}

#[tokio::test]
async fn test_chain_err_err_ipv4() {
let v = vec![DenyAllDns::new(), DenyAllDns::new()];
assert!(v
.ipv4_lookup(Domain::from_static("example.com"))
.await
.is_err());
}
}
13 changes: 13 additions & 0 deletions rama-dns/src/in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ pub struct InMemoryDns {
}

impl InMemoryDns {
/// Creates a new empty [`InMemoryDns`] instance.
pub fn new() -> Self {
Self { map: None }
}

/// Inserts a domain to IP address mapping to the [`InMemoryDns`].
///
/// Existing mappings will be overwritten.
Expand All @@ -59,6 +64,14 @@ impl InMemoryDns {
self
}

/// Insert an IP address for a domain.
///
/// This method accepts any type that can be converted into an `IpAddr`,
/// such as `Ipv4Addr` or `Ipv6Addr`.
pub fn insert_addr<A: Into<IpAddr>>(&mut self, name: Domain, addr: A) -> &mut Self {
self.insert(name.into(), vec![addr.into()])
}

/// Extend the [`InMemoryDns`] with the given mappings.
///
/// Existing mappings will be overwritten.
Expand Down
43 changes: 4 additions & 39 deletions rama-dns/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,45 +79,6 @@ impl<R: DnsResolver<Error: Into<BoxError>>> DnsResolver for Option<R> {
}
}

macro_rules! impl_dns_resolver_either_either {
($id:ident, $($param:ident),+ $(,)?) => {
impl<$($param),+> DnsResolver for ::rama_core::combinators::$id<$($param),+>
where
$($param: DnsResolver<Error: Into<::rama_core::error::BoxError>>),+,
{
type Error = ::rama_core::error::BoxError;

async fn ipv4_lookup(
&self,
domain: Domain,
) -> Result<Vec<Ipv4Addr>, Self::Error>{
match self {
$(
::rama_core::combinators::$id::$param(d) => d.ipv4_lookup(domain)
.await
.map_err(Into::into),
)+
}
}

async fn ipv6_lookup(
&self,
domain: Domain,
) -> Result<Vec<Ipv6Addr>, Self::Error> {
match self {
$(
::rama_core::combinators::$id::$param(d) => d.ipv6_lookup(domain)
.await
.map_err(Into::into),
)+
}
}
}
};
}

rama_core::combinators::impl_either!(impl_dns_resolver_either_either);

pub mod hickory;
#[doc(inline)]
pub use hickory::HickoryDns;
Expand All @@ -129,3 +90,7 @@ pub use in_memory::{DnsOverwrite, DomainNotMappedErr, InMemoryDns};
mod deny_all;
#[doc(inline)]
pub use deny_all::{DenyAllDns, DnsDeniedError};

pub mod chain;

mod variant;
119 changes: 119 additions & 0 deletions rama-dns/src/variant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use crate::DnsResolver;
use rama_net::address::Domain;
use std::net::{Ipv4Addr, Ipv6Addr};

macro_rules! impl_dns_resolver_either_either {
($id:ident, $($param:ident),+ $(,)?) => {
impl<$($param),+> DnsResolver for ::rama_core::combinators::$id<$($param),+>
where
$($param: DnsResolver<Error: Into<::rama_core::error::BoxError>>),+,
{
type Error = ::rama_core::error::BoxError;

async fn ipv4_lookup(
&self,
domain: Domain,
) -> Result<Vec<Ipv4Addr>, Self::Error>{
match self {
$(
::rama_core::combinators::$id::$param(d) => d.ipv4_lookup(domain)
.await
.map_err(Into::into),
)+
}
}

async fn ipv6_lookup(
&self,
domain: Domain,
) -> Result<Vec<Ipv6Addr>, Self::Error> {
match self {
$(
::rama_core::combinators::$id::$param(d) => d.ipv6_lookup(domain)
.await
.map_err(Into::into),
)+
}
}
}
};
}

rama_core::combinators::impl_either!(impl_dns_resolver_either_either);

#[cfg(test)]
mod tests {
use crate::DnsResolver;
use rama_core::combinators::Either;
use rama_net::address::Domain;
use std::future::Future;
use std::net::{Ipv4Addr, Ipv6Addr};

// Mock DNS resolvers for testing
struct MockResolver1;
struct MockResolver2;

impl DnsResolver for MockResolver1 {
type Error = Box<dyn std::error::Error + Send + Sync>;

fn ipv4_lookup(
&self,
_domain: Domain,
) -> impl Future<Output = Result<Vec<Ipv4Addr>, Self::Error>> {
std::future::ready(Ok(vec![Ipv4Addr::new(127, 0, 0, 1)]))
}

fn ipv6_lookup(
&self,
_domain: Domain,
) -> impl Future<Output = Result<Vec<Ipv6Addr>, Self::Error>> {
std::future::ready(Ok(vec![Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)]))
}
}

impl DnsResolver for MockResolver2 {
type Error = Box<dyn std::error::Error + Send + Sync>;

fn ipv4_lookup(
&self,
_domain: Domain,
) -> impl Future<Output = Result<Vec<Ipv4Addr>, Self::Error>> + Send + '_ {
std::future::ready(Ok(vec![Ipv4Addr::new(192, 168, 1, 1)]))
}

fn ipv6_lookup(
&self,
_domain: Domain,
) -> impl Future<Output = Result<Vec<Ipv6Addr>, Self::Error>> + Send + '_ {
std::future::ready(Ok(vec![Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2)]))
}
}

#[tokio::test]
async fn test_either_ipv4_lookup() {
let resolver1 = Either::<MockResolver1, MockResolver2>::A(MockResolver1);
let resolver2 = Either::<MockResolver1, MockResolver2>::B(MockResolver2);

let domain = "example.com".parse::<Domain>().unwrap();

let result1 = resolver1.ipv4_lookup(domain.clone()).await.unwrap();
assert_eq!(result1, vec![Ipv4Addr::new(127, 0, 0, 1)]);

let result2 = resolver2.ipv4_lookup(domain).await.unwrap();
assert_eq!(result2, vec![Ipv4Addr::new(192, 168, 1, 1)]);
}

#[tokio::test]
async fn test_either_ipv6_lookup() {
let resolver1 = Either::<MockResolver1, MockResolver2>::A(MockResolver1);
let resolver2 = Either::<MockResolver1, MockResolver2>::B(MockResolver2);

let domain = "example.com".parse::<Domain>().unwrap();

let result1 = resolver1.ipv6_lookup(domain.clone()).await.unwrap();
assert_eq!(result1, vec![Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)]);

let result2 = resolver2.ipv6_lookup(domain).await.unwrap();
assert_eq!(result2, vec![Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2)]);
}
}

0 comments on commit b890732

Please sign in to comment.