Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Zip::apply_collect for non-Copy elements too #814

Merged
merged 3 commits into from
Apr 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,38 @@ fn add_2d_alloc_zip_collect(bench: &mut test::Bencher) {
});
}

#[bench]
fn vec_string_collect(bench: &mut test::Bencher) {
let v = vec![""; 10240];
bench.iter(|| {
v.iter().map(|s| s.to_owned()).collect::<Vec<_>>()
});
}

#[bench]
fn array_string_collect(bench: &mut test::Bencher) {
let v = Array::from(vec![""; 10240]);
bench.iter(|| {
Zip::from(&v).apply_collect(|s| s.to_owned())
});
}

#[bench]
fn vec_f64_collect(bench: &mut test::Bencher) {
let v = vec![1.; 10240];
bench.iter(|| {
v.iter().map(|s| s + 1.).collect::<Vec<_>>()
});
}

#[bench]
fn array_f64_collect(bench: &mut test::Bencher) {
let v = Array::from(vec![1.; 10240]);
bench.iter(|| {
Zip::from(&v).apply_collect(|s| s + 1.)
});
}


#[bench]
fn add_2d_assign_ops(bench: &mut test::Bencher) {
Expand Down
26 changes: 19 additions & 7 deletions src/zip/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#[macro_use]
mod zipmacro;
mod partial_array;

use std::mem::MaybeUninit;

Expand All @@ -20,6 +21,8 @@ use crate::NdIndex;
use crate::indexes::{indices, Indices};
use crate::layout::{CORDER, FORDER};

use partial_array::PartialArray;

/// Return if the expression is a break value.
macro_rules! fold_while {
($e:expr) => {
Expand Down Expand Up @@ -195,6 +198,7 @@ pub trait NdProducer {
fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
where
Self: Sized;

private_decl! {}
}

Expand Down Expand Up @@ -1070,16 +1074,24 @@ macro_rules! map_impl {
/// inputs.
///
/// If all inputs are c- or f-order respectively, that is preserved in the output.
///
/// Restricted to functions that produce copyable results for technical reasons; other
/// cases are not yet implemented.
pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D>
where R: Copy,
{
// To support non-Copy elements, implementation of dropping partial array (on
// panic) is needed
// Make uninit result
let mut output = self.uninitalized_for_current_layout::<R>();
self.apply_assign_into(&mut output, f);
if !std::mem::needs_drop::<R>() {
// For elements with no drop glue, just overwrite into the array
self.apply_assign_into(&mut output, f);
} else {
// For generic elements, use a proxy that counts the number of filled elements,
// and can drop the right number of elements on unwinding
unsafe {
PartialArray::scope(output.view_mut(), move |partial| {
debug_assert_eq!(partial.layout().tendency() >= 0, self.layout_tendency >= 0);
self.apply_assign_into(partial, f);
});
}
}

unsafe {
output.assume_init()
}
Expand Down
144 changes: 144 additions & 0 deletions src/zip/partial_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright 2020 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use crate::imp_prelude::*;
use crate::{
AssignElem,
Layout,
NdProducer,
Zip,
FoldWhile,
};

use std::cell::Cell;
use std::mem;
use std::mem::MaybeUninit;
use std::ptr;

/// An assignable element reference that increments a counter when assigned
pub(crate) struct ProxyElem<'a, 'b, A> {
item: &'a mut MaybeUninit<A>,
filled: &'b Cell<usize>
}

impl<'a, 'b, A> AssignElem<A> for ProxyElem<'a, 'b, A> {
fn assign_elem(self, item: A) {
self.filled.set(self.filled.get() + 1);
*self.item = MaybeUninit::new(item);
}
}

/// Handles progress of assigning to a part of an array, for elements that need
/// to be dropped on unwinding. See Self::scope.
pub(crate) struct PartialArray<'a, 'b, A, D>
where D: Dimension
{
data: ArrayViewMut<'a, MaybeUninit<A>, D>,
filled: &'b Cell<usize>,
}

impl<'a, 'b, A, D> PartialArray<'a, 'b, A, D>
where D: Dimension
{
/// Create a temporary PartialArray that wraps the array view `data`;
/// if the end of the scope is reached, the partial array is marked complete;
/// if execution unwinds at any time before them, the elements written until then
/// are dropped.
///
/// Safety: the caller *must* ensure that elements will be written in `data`'s preferred order.
/// PartialArray can not handle arbitrary writes, only in the memory order.
pub(crate) unsafe fn scope(data: ArrayViewMut<'a, MaybeUninit<A>, D>,
scope_fn: impl FnOnce(&mut PartialArray<A, D>))
{
let filled = Cell::new(0);
let mut partial = PartialArray::new(data, &filled);
scope_fn(&mut partial);
filled.set(0); // mark complete
}

unsafe fn new(data: ArrayViewMut<'a, MaybeUninit<A>, D>,
filled: &'b Cell<usize>) -> Self
{
debug_assert_eq!(filled.get(), 0);
Self { data, filled }
}
}

impl<'a, 'b, A, D> Drop for PartialArray<'a, 'b, A, D>
where D: Dimension
{
fn drop(&mut self) {
if !mem::needs_drop::<A>() {
return;
}

let mut count = self.filled.get();
if count == 0 {
return;
}

Zip::from(self).fold_while((), move |(), elt| {
if count > 0 {
count -= 1;
unsafe {
ptr::drop_in_place::<A>(elt.item.as_mut_ptr());
}
FoldWhile::Continue(())
} else {
FoldWhile::Done(())
}
});
}
}

impl<'a: 'c, 'b: 'c, 'c, A, D: Dimension> NdProducer for &'c mut PartialArray<'a, 'b, A, D> {
// This just wraps ArrayViewMut as NdProducer and maps the item
type Item = ProxyElem<'a, 'b, A>;
type Dim = D;
type Ptr = *mut MaybeUninit<A>;
type Stride = isize;

private_impl! {}
fn raw_dim(&self) -> Self::Dim {
self.data.raw_dim()
}

fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.data.equal_dim(dim)
}

fn as_ptr(&self) -> Self::Ptr {
NdProducer::as_ptr(&self.data)
}

fn layout(&self) -> Layout {
self.data.layout()
}

unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
ProxyElem { filled: self.filled, item: &mut *ptr }
}

unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
self.data.uget_ptr(i)
}

fn stride_of(&self, axis: Axis) -> Self::Stride {
self.data.stride_of(axis)
}

#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
self.data.contiguous_stride()
}

fn split_at(self, _axis: Axis, _index: usize) -> (Self, Self) {
unimplemented!();
}
}

72 changes: 72 additions & 0 deletions tests/azip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,78 @@ fn test_zip_assign_into_cell() {
assert_abs_diff_eq!(a2, &b + &c, epsilon = 1e-6);
}

#[test]
fn test_zip_collect_drop() {
use std::cell::RefCell;
use std::panic;

struct Recorddrop<'a>((usize, usize), &'a RefCell<Vec<(usize, usize)>>);

impl<'a> Drop for Recorddrop<'a> {
fn drop(&mut self) {
self.1.borrow_mut().push(self.0);
}
}

#[derive(Copy, Clone)]
enum Config {
CC,
CF,
FF,
}

impl Config {
fn a_is_f(self) -> bool {
match self {
Config::CC | Config::CF => false,
_ => true,
}
}
fn b_is_f(self) -> bool {
match self {
Config::CC => false,
_ => true,
}
}
}

let test_collect_panic = |config: Config, will_panic: bool, slice: bool| {
let mut inserts = RefCell::new(Vec::new());
let mut drops = RefCell::new(Vec::new());

let mut a = Array::from_shape_fn((5, 10).set_f(config.a_is_f()), |idx| idx);
let mut b = Array::from_shape_fn((5, 10).set_f(config.b_is_f()), |_| 0);
if slice {
a = a.slice_move(s![.., ..-1]);
b = b.slice_move(s![.., ..-1]);
}

let _result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
Zip::from(&a).and(&b).apply_collect(|&elt, _| {
if elt.0 > 3 && will_panic {
panic!();
}
inserts.borrow_mut().push(elt);
Recorddrop(elt, &drops)
});
}));

println!("{:?}", inserts.get_mut());
println!("{:?}", drops.get_mut());

assert_eq!(inserts.get_mut().len(), drops.get_mut().len(), "Incorrect number of drops");
assert_eq!(inserts.get_mut(), drops.get_mut(), "Incorrect order of drops");
};

for &should_panic in &[true, false] {
for &should_slice in &[false, true] {
test_collect_panic(Config::CC, should_panic, should_slice);
test_collect_panic(Config::CF, should_panic, should_slice);
test_collect_panic(Config::FF, should_panic, should_slice);
}
}
}


#[test]
fn test_azip_syntax_trailing_comma() {
Expand Down