From ca1abfeb2dc16453df702097f13c956a27b4c9a4 Mon Sep 17 00:00:00 2001 From: Michael Sproul Date: Wed, 11 Oct 2023 10:43:49 +1100 Subject: [PATCH] Support iterables in compare_fields --- Cargo.lock | 1 + common/compare_fields/Cargo.toml | 3 +++ common/compare_fields/src/lib.rs | 36 ++++++++++++++++++++----- common/compare_fields_derive/src/lib.rs | 13 ++++----- 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 697cdc1796b..a7862f321c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1153,6 +1153,7 @@ name = "compare_fields" version = "0.2.0" dependencies = [ "compare_fields_derive", + "itertools", ] [[package]] diff --git a/common/compare_fields/Cargo.toml b/common/compare_fields/Cargo.toml index 8df989e7225..9972ca75ca6 100644 --- a/common/compare_fields/Cargo.toml +++ b/common/compare_fields/Cargo.toml @@ -4,6 +4,9 @@ version = "0.2.0" authors = ["Paul Hauner "] edition = { workspace = true } +[dependencies] +itertools = { workspace = true } + [dev-dependencies] compare_fields_derive = { workspace = true } diff --git a/common/compare_fields/src/lib.rs b/common/compare_fields/src/lib.rs index bc2f5446ad2..27baf148067 100644 --- a/common/compare_fields/src/lib.rs +++ b/common/compare_fields/src/lib.rs @@ -81,11 +81,8 @@ //! } //! ]; //! assert_eq!(bar_a.compare_fields(&bar_b), bar_a_b); -//! -//! -//! -//! // TODO: //! ``` +use itertools::{EitherOrBoth, Itertools}; use std::fmt::Debug; #[derive(Debug, PartialEq, Clone)] @@ -112,13 +109,38 @@ impl Comparison { } pub fn from_slice>(field_name: String, a: &[T], b: &[T]) -> Self { + Self::from_iter(field_name, a.iter(), b.iter()) + } + + pub fn from_into_iter<'a, T: Debug + PartialEq + 'a>( + field_name: String, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> Self { + Self::from_iter(field_name, a.into_iter(), b.into_iter()) + } + + pub fn from_iter<'a, T: Debug + PartialEq + 'a>( + field_name: String, + a: impl Iterator, + b: impl Iterator, + ) -> Self { let mut children = vec![]; + let mut all_equal = true; - for i in 0..std::cmp::max(a.len(), b.len()) { - children.push(FieldComparison::new(format!("{i}"), &a.get(i), &b.get(i))); + for (i, entry) in a.zip_longest(b).enumerate() { + let comparison = match entry { + EitherOrBoth::Both(x, y) => { + FieldComparison::new(format!("{i}"), &Some(x), &Some(y)) + } + EitherOrBoth::Left(x) => FieldComparison::new(format!("{i}"), &Some(x), &None), + EitherOrBoth::Right(y) => FieldComparison::new(format!("{i}"), &None, &Some(y)), + }; + all_equal = all_equal && comparison.equal(); + children.push(comparison); } - Self::parent(field_name, a == b, children) + Self::parent(field_name, all_equal, children) } pub fn retain_children(&mut self, f: F) diff --git a/common/compare_fields_derive/src/lib.rs b/common/compare_fields_derive/src/lib.rs index a8b92b3d548..099db8e791e 100644 --- a/common/compare_fields_derive/src/lib.rs +++ b/common/compare_fields_derive/src/lib.rs @@ -4,10 +4,11 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, DeriveInput}; -fn is_slice(field: &syn::Field) -> bool { +fn is_iter(field: &syn::Field) -> bool { field.attrs.iter().any(|attr| { attr.path.is_ident("compare_fields") - && attr.tokens.to_string().replace(' ', "") == "(as_slice)" + && (attr.tokens.to_string().replace(' ', "") == "(as_slice)" + || attr.tokens.to_string().replace(' ', "") == "(as_iter)") }) } @@ -34,13 +35,13 @@ pub fn compare_fields_derive(input: TokenStream) -> TokenStream { let field_name = ident_a.to_string(); let ident_b = ident_a.clone(); - let quote = if is_slice(field) { + let quote = if is_iter(field) { quote! { - comparisons.push(compare_fields::Comparison::from_slice( + comparisons.push(compare_fields::Comparison::from_into_iter( #field_name.to_string(), &self.#ident_a, - &b.#ident_b) - ); + &b.#ident_b + )); } } else { quote! {