Skip to content

Commit

Permalink
Rewrite the multi cartesian product iterator to both simplify it and …
Browse files Browse the repository at this point in the history
…fix a bug.
  • Loading branch information
JakobDegen committed Feb 8, 2022
1 parent 6c4fc2f commit 1a1d6cc
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 158 deletions.
295 changes: 137 additions & 158 deletions src/adaptors/multi_product.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#![cfg(feature = "use_alloc")]

use crate::size_hint;
use crate::Itertools;

use alloc::vec::Vec;

Expand All @@ -14,217 +13,197 @@ use alloc::vec::Vec;
/// See [`.multi_cartesian_product()`](crate::Itertools::multi_cartesian_product)
/// for more information.
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
pub struct MultiProduct<I>(Vec<MultiProductIter<I>>)
where I: Iterator + Clone,
I::Item: Clone;
pub struct MultiProduct<I>
where
I: Iterator + Clone,
I::Item: Clone,
{
// The last thing we returned
state: MultiProductState<I::Item>,
iters: Vec<MultiProductIter<I>>,
}

impl<I> std::fmt::Debug for MultiProduct<I>
where
I: Iterator + Clone + std::fmt::Debug,
I::Item: Clone + std::fmt::Debug,
{
debug_fmt_fields!(CoalesceBy, 0);
debug_fmt_fields!(CoalesceBy, iters);
}

/// Stores the current state of the iterator.
#[derive(Clone)]
enum MultiProductState<I> {
/// In the middle of an iteration. The `Vec<I>` is the last value we returned
InProgress(Vec<I>),
/// At the beginning of an iteration. The `Vec<I>` is the next value to be returned.
Restarted(Vec<I>),
/// Iteration has not been started
Unstarted,
}
use MultiProductState::*;

/// Create a new cartesian product iterator over an arbitrary number
/// of iterators of the same type.
///
/// Iterator element is of type `Vec<H::Item::Item>`.
pub fn multi_cartesian_product<H>(iters: H) -> MultiProduct<<H::Item as IntoIterator>::IntoIter>
where H: Iterator,
H::Item: IntoIterator,
<H::Item as IntoIterator>::IntoIter: Clone,
<H::Item as IntoIterator>::Item: Clone
where
H: Iterator,
H::Item: IntoIterator,
<H::Item as IntoIterator>::IntoIter: Clone,
<H::Item as IntoIterator>::Item: Clone,
{
MultiProduct(iters.map(|i| MultiProductIter::new(i.into_iter())).collect())
MultiProduct {
state: MultiProductState::Unstarted,
iters: iters
.map(|i| MultiProductIter::new(i.into_iter()))
.collect(),
}
}

#[derive(Clone, Debug)]
/// Holds the state of a single iterator within a MultiProduct.
struct MultiProductIter<I>
where I: Iterator + Clone,
I::Item: Clone
where
I: Iterator + Clone,
I::Item: Clone,
{
cur: Option<I::Item>,
iter: I,
iter_orig: I,
}

/// Holds the current state during an iteration of a MultiProduct.
#[derive(Debug)]
enum MultiProductIterState {
StartOfIter,
MidIter { on_first_iter: bool },
}

impl<I> MultiProduct<I>
where I: Iterator + Clone,
I::Item: Clone
impl<I> MultiProductIter<I>
where
I: Iterator + Clone,
I::Item: Clone,
{
/// Iterates the rightmost iterator, then recursively iterates iterators
/// to the left if necessary.
///
/// Returns true if the iteration succeeded, else false.
fn iterate_last(
multi_iters: &mut [MultiProductIter<I>],
mut state: MultiProductIterState
) -> bool {
use self::MultiProductIterState::*;

if let Some((last, rest)) = multi_iters.split_last_mut() {
let on_first_iter = match state {
StartOfIter => {
let on_first_iter = !last.in_progress();
state = MidIter { on_first_iter };
on_first_iter
},
MidIter { on_first_iter } => on_first_iter
};

if !on_first_iter {
last.iterate();
}

if last.in_progress() {
true
} else if MultiProduct::iterate_last(rest, state) {
last.reset();
last.iterate();
// If iterator is None twice consecutively, then iterator is
// empty; whole product is empty.
last.in_progress()
} else {
false
}
} else {
// Reached end of iterator list. On initialisation, return true.
// At end of iteration (final iterator finishes), finish.
match state {
StartOfIter => false,
MidIter { on_first_iter } => on_first_iter
}
}
}

/// Returns the unwrapped value of the next iteration.
fn curr_iterator(&self) -> Vec<I::Item> {
self.0.iter().map(|multi_iter| {
multi_iter.cur.clone().unwrap()
}).collect()
}

/// Returns true if iteration has started and has not yet finished; false
/// otherwise.
fn in_progress(&self) -> bool {
if let Some(last) = self.0.last() {
last.in_progress()
} else {
false
}
fn reset(&mut self) {
self.iter = self.iter_orig.clone();
}
}

impl<I> MultiProductIter<I>
where I: Iterator + Clone,
I::Item: Clone
{
fn new(iter: I) -> Self {
MultiProductIter {
cur: None,
iter: iter.clone(),
iter_orig: iter
iter_orig: iter,
}
}

/// Iterate the managed iterator.
fn iterate(&mut self) {
self.cur = self.iter.next();
}

/// Reset the managed iterator.
fn reset(&mut self) {
self.iter = self.iter_orig.clone();
}

/// Returns true if the current iterator has been started and has not yet
/// finished; false otherwise.
fn in_progress(&self) -> bool {
self.cur.is_some()
fn next(&mut self) -> Option<I::Item> {
self.iter.next()
}
}

impl<I> Iterator for MultiProduct<I>
where I: Iterator + Clone,
I::Item: Clone
where
I: Iterator + Clone,
I::Item: Clone,
{
type Item = Vec<I::Item>;

fn next(&mut self) -> Option<Self::Item> {
if MultiProduct::iterate_last(
&mut self.0,
MultiProductIterState::StartOfIter
) {
Some(self.curr_iterator())
} else {
None
let last = match &mut self.state {
InProgress(v) => v,
Restarted(v) => {
let v = core::mem::take(v);
self.state = InProgress(v.clone());
return Some(v);
}
Unstarted => {
let next: Option<Vec<_>> = self.iters.iter_mut().map(|i| i.next()).collect();
if let Some(v) = &next {
self.state = InProgress(v.clone());
}
return next;
}
};

// Starting from the last iterator, advance each iterator until we find one that returns a
// value.
for i in (0..self.iters.len()).rev() {
let iter = &mut self.iters[i];
let loc = &mut last[i];
if let Some(val) = iter.next() {
*loc = val;
return Some(last.clone());
} else {
iter.reset();
if let Some(val) = iter.next() {
*loc = val;
} else {
// This case should not really take place; we had an in progress iterator, reset
// it, and called `.next()`, but now its empty. In any case, the product is
// empty now and we should handle things accordingly.
self.state = Unstarted;
return None;
}
}
}

// Reaching here indicates that all the iterators returned none, and so iteration has completed
let v = core::mem::take(last);
self.state = Restarted(v);
None
}

fn count(self) -> usize {
if self.0.is_empty() {
return 0;
}

if !self.in_progress() {
return self.0.into_iter().fold(1, |acc, multi_iter| {
acc * multi_iter.iter.count()
});
// `remaining` is the number of remaining iterations before the current iterator is
// exhausted. `per_reset` is the number of total iterations that take place each time the
// current iterator is reset
let (remaining, per_reset) =
self.iters
.into_iter()
.rev()
.fold((0, 1), |(remaining, per_reset), iter| {
let remaining = remaining + per_reset * iter.iter.count();
let per_reset = per_reset * iter.iter_orig.count();
(remaining, per_reset)
});
if let Restarted(_) | Unstarted = &self.state {
per_reset
} else {
remaining
}

self.0.into_iter().fold(
0,
|acc, MultiProductIter { iter, iter_orig, cur: _ }| {
let total_count = iter_orig.count();
let cur_count = iter.count();
acc * total_count + cur_count
}
)
}

fn size_hint(&self) -> (usize, Option<usize>) {
// Not ExactSizeIterator because size may be larger than usize
if self.0.is_empty() {
return (0, Some(0));
}

if !self.in_progress() {
return self.0.iter().fold((1, Some(1)), |acc, multi_iter| {
size_hint::mul(acc, multi_iter.iter.size_hint())
});
let initial = ((0, Some(0)), (1, Some(1)));
// Exact same logic as for `count`
let (remaining, per_reset) =
self.iters
.iter()
.rev()
.fold(initial, |(remaining, per_reset), iter| {
let prod = size_hint::mul(per_reset, iter.iter.size_hint());
let remaining = size_hint::add(remaining, prod);
let per_reset = size_hint::mul(per_reset, iter.iter_orig.size_hint());
(remaining, per_reset)
});
if let Restarted(_) | Unstarted = &self.state {
per_reset
} else {
remaining
}

self.0.iter().fold(
(0, Some(0)),
|acc, &MultiProductIter { ref iter, ref iter_orig, cur: _ }| {
let cur_size = iter.size_hint();
let total_size = iter_orig.size_hint();
size_hint::add(size_hint::mul(acc, total_size), cur_size)
}
)
}

fn last(self) -> Option<Self::Item> {
let iter_count = self.0.len();

let lasts: Self::Item = self.0.into_iter()
.map(|multi_iter| multi_iter.iter.last())
.while_some()
.collect();

if lasts.len() == iter_count {
Some(lasts)
// The way resetting works makes the first iterator a little bit special
let mut iter = self.iters.into_iter();
if let Some(first) = iter.next() {
let first = if let Restarted(_) | Unstarted = &self.state {
first.iter_orig.last()
} else {
first.iter.last()
};
core::iter::once(first)
.chain(iter.map(|sub| sub.iter_orig.last()))
.collect()
} else {
None
if let Restarted(_) | Unstarted = &self.state {
Some(Vec::new())
} else {
None
}
}
}
}
6 changes: 6 additions & 0 deletions tests/quick.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,12 @@ quickcheck! {
assert_eq!(answer.into_iter().last(), a.clone().multi_cartesian_product().last());
}

fn correct_empty_multi_product() -> () {
let mut empty = Vec::<std::vec::IntoIter<i32>>::new().into_iter().multi_cartesian_product();
assert!(correct_size_hint(empty.clone()));
assert_eq!(empty.next(), Some(Vec::new()))
}

#[allow(deprecated)]
fn size_step(a: Iter<i16, Exact>, s: usize) -> bool {
let mut s = s;
Expand Down

0 comments on commit 1a1d6cc

Please sign in to comment.