diff --git a/src/future/shared.rs b/src/future/shared.rs
index 82f3d5f5000..ea422ea3694 100644
--- a/src/future/shared.rs
+++ b/src/future/shared.rs
@@ -15,41 +15,22 @@

 use std::mem;
 use std::vec::Vec;
-use std::sync::{Arc, RwLock};
-use std::sync::atomic::AtomicBool;
-use std::sync::atomic::Ordering::SeqCst;
+use std::sync::{Arc, Mutex};
 use std::ops::Deref;

 use {Future, Poll, Async};
 use task::{self, Task};
-use lock::Lock;

 /// A future that is cloneable and can be polled in multiple threads.
 /// Use Future::shared() method to convert any future into a `Shared` future.
 #[must_use = "futures do nothing unless polled"]
-pub struct Shared<F>
-    where F: Future
-{
-    inner: Arc<Inner<F>>,
-}
-
-struct Inner<F>
-    where F: Future
-{
-    /// The original future.
-    original_future: Lock<Option<F>>,
-    /// Indicates whether the result is ready, and the state is `State::Done`.
-    result_ready: AtomicBool,
-    /// The state of the shared future.
-    state: RwLock<State<F::Item, F::Error>>,
+pub struct Shared<F: Future> {
+    inner: Arc<Mutex<State<F>>>,
 }

-/// The state of the shared future. It can be one of the following:
-/// 1. Done - contains the result of the original future.
-/// 2. Waiting - contains the waiting tasks.
-enum State<T, E> {
-    Waiting(Vec<Task>),
-    Done(Result<Arc<T>, Arc<E>>),
+enum State<F: Future> {
+    Waiting(F, Vec<Task>),
+    Done(Result<Arc<F::Item>, Arc<F::Error>>),
 }

 impl<F> Shared<F>
@@ -58,23 +39,7 @@ impl<F> Shared<F>
     /// Creates a new `Shared` from another future.
     pub fn new(future: F) -> Self {
         Shared {
-            inner: Arc::new(Inner {
-                original_future: Lock::new(Some(future)),
-                result_ready: AtomicBool::new(false),
-                state: RwLock::new(State::Waiting(vec![])),
-            }),
-        }
-    }
-
-    fn park(&self) -> Poll<SharedItem<F::Item>, SharedError<F::Error>> {
-        let me = task::park();
-        match *self.inner.state.write().unwrap() {
-            State::Waiting(ref mut list) => {
-                list.push(me);
-                Ok(Async::NotReady)
-            }
-            State::Done(Ok(ref e)) => Ok(SharedItem { item: e.clone() }.into()),
-            State::Done(Err(ref e)) => Err(SharedError { error: e.clone() }),
+            inner: Arc::new(Mutex::new(State::Waiting(future, Vec::new()))),
         }
     }
 }
@@ -86,79 +51,40 @@ impl<F> Future for Shared<F>
     type Error = SharedError<F::Error>;

     fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
-        // The logic is as follows:
-        // 1. Check if the result is ready (with result_ready)
-        //    - If the result is ready, return it.
-        //    - Otherwise:
-        // 2. Try lock the self.inner.original_future:
-        //    - If successfully locked, check again if the result is ready.
-        //      If it's ready, just return it.
-        //      Otherwise, poll the original future.
-        //      If the future is ready, unpark the waiting tasks from
-        //      self.inner.state and return the result.
-        //    - If the future is not ready, or if the lock failed:
-        // 3. Lock the state for write.
-        // 4. If the state is `State::Done`, return the result. Otherwise:
-        // 5. Create a task, push it to the waiters vector, and return
-        //    `Ok(Async::NotReady)`.
-
-        if !self.inner.result_ready.load(SeqCst) {
-            match self.inner.original_future.try_lock() {
-                // We already saw the result wasn't ready, but after we've
-                // acquired the lock another thread could already have finished,
-                // so we check `result_ready` again.
-                Some(_) if self.inner.result_ready.load(SeqCst) => {}
-
-                // If we lock the future, then try to push it towards
-                // completion.
-                Some(mut future) => {
-                    let result = match future.as_mut().unwrap().poll() {
-                        Ok(Async::NotReady) => {
-                            drop(future);
-                            return self.park()
-                        }
-                        Ok(Async::Ready(item)) => Ok(Arc::new(item)),
-                        Err(error) => Err(Arc::new(error)),
-                    };
-
-                    // Free up resources associated with this future
-                    *future = None;
-
-                    // Wake up everyone waiting on the future and store the
-                    // result at the same time, flagging future pollers that
-                    // we're done.
-                    let waiters = {
-                        let mut state = self.inner.state.write().unwrap();
-                        self.inner.result_ready.store(true, SeqCst);
-
-                        match mem::replace(&mut *state, State::Done(result)) {
-                            State::Waiting(waiters) => waiters,
-                            State::Done(_) => {
-                                panic!("store_result() was called twice")
-                            }
-                        }
-                    };
-                    for task in waiters {
-                        task.unpark();
-                    }
-                }
-
-                // Looks like someone else is making progress on the future,
-                // let's just wait for them.
-                None => return self.park(),
-            }
-        }
-
-        // If we're here then we should have finished the future, so assert the
-        // `Done` state and return the item/error.
-        let result = match *self.inner.state.read().unwrap() {
-            State::Done(ref result) => result.clone(),
-            State::Waiting(_) => panic!("still waiting, not done yet"),
+        let mut inner = self.inner.lock().unwrap();
+        let result = match *inner {
+            State::Waiting(ref mut future, _) => Some(future.poll()),
+            State::Done(_) => None,
+        };
+        let new_state = match result {
+            Some(Ok(Async::NotReady)) => None,
+            Some(Ok(Async::Ready(e))) => Some(State::Done(Ok(Arc::new(e)))),
+            Some(Err(e)) => Some(State::Done(Err(Arc::new(e)))),
+            None => None,
+        };
+        let tasks_to_wake = match new_state {
+            Some(new) => {
+                match mem::replace(&mut *inner, new) {
+                    State::Waiting(_, tasks) => tasks,
+                    State::Done(_) => panic!(),
+                }
+            }
+            None => Vec::new(),
+        };
+
+        let ret = match *inner {
+            State::Waiting(_, ref mut tasks) => {
+                tasks.push(task::park());
+                Ok(Async::NotReady)
+            }
+            State::Done(Ok(ref e)) => Ok(SharedItem { item: e.clone() }.into()),
+            State::Done(Err(ref e)) => Err(SharedError { error: e.clone() }.into()),
         };

-        match result {
-            Ok(e) => Ok(SharedItem { item: e }.into()),
-            Err(e) => Err(SharedError { error: e }),
+        drop(inner);
+        for task in tasks_to_wake {
+            task.unpark();
         }
+        return ret
     }
 }
@@ -185,13 +111,23 @@ impl<F> Drop for Shared<F> {
         // other waiting tasks whenever we're dropped. This should go through
         // and wake up any interested handles, and at least one of them should
         // make its way to blocking on the original future itself.
-        if self.inner.result_ready.load(SeqCst) {
-            return
-        }
-        let waiters = match *self.inner.state.write().unwrap() {
-            State::Waiting(ref mut waiters) => mem::replace(waiters, Vec::new()),
+        //
+        // Note, though, that we don't call `lock` here but rather we just call
+        // `try_lock`. This is done because during a `poll` above, when the lock
+        // is held, we may end up calling this drop function. If that happens
+        // then this `try_lock` will fail, or the `try_lock` will fail due to
+        // another thread holding the lock. In both cases we're guaranteed that
+        // some other thread/task other than us is blocked on the main future,
+        // so there's no work for us to do.
+        let mut inner = match self.inner.try_lock() {
+            Ok(inner) => inner,
+            Err(_) => return,
+        };
+        let waiters = match *inner {
+            State::Waiting(_, ref mut waiters) => mem::replace(waiters, Vec::new()),
             State::Done(_) => return,
         };
+        drop(inner);
         for waiter in waiters {
             waiter.unpark();
         }
diff --git a/tests/shared.rs b/tests/shared.rs
index e1a418b36df..876c05f0e75 100644
--- a/tests/shared.rs
+++ b/tests/shared.rs
@@ -1,26 +1,26 @@
 extern crate futures;

+use std::cell::RefCell;
+use std::rc::Rc;
 use std::thread;

 use futures::sync::oneshot;
 use futures::Future;
+use futures::future;

 fn send_shared_oneshot_and_wait_on_multiple_threads(threads_number: u32) {
     let (tx, rx) = oneshot::channel::<u32>();
     let f = rx.shared();
-    let mut cloned_futures_waited_oneshots = vec![];
-    for _ in 0..threads_number {
+    let threads = (0..threads_number).map(|_| {
         let cloned_future = f.clone();
-        let (tx2, rx2) = oneshot::channel::<()>();
-        cloned_futures_waited_oneshots.push(rx2);
         thread::spawn(move || {
             assert!(*cloned_future.wait().unwrap() == 6);
-            tx2.complete(());
-        });
-    }
+        })
+    }).collect::<Vec<_>>();
     tx.complete(6);
-    for f in cloned_futures_waited_oneshots {
-        f.wait().unwrap();
+    assert!(*f.wait().unwrap() == 6);
+    for f in threads {
+        f.join().unwrap();
     }
 }
@@ -66,3 +66,16 @@ fn drop_on_one_task_ok() {
     assert_eq!(result, 42);
     t2.join().unwrap();
 }
+
+#[test]
+fn drop_in_poll() {
+    let slot = Rc::new(RefCell::new(None));
+    let slot2 = slot.clone();
+    let future = future::poll_fn(move || {
+        drop(slot2.borrow_mut().take().unwrap());
+        Ok::<_, u32>(1.into())
+    }).shared();
+    let future2 = Box::new(future.clone()) as Box<Future<Item=_, Error=_>>;
+    *slot.borrow_mut() = Some(future2);
+    assert_eq!(*future.wait().unwrap(), 1);
+}
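
For context when reviewing, here is a small usage sketch (not part of this patch, illustrative only) that exercises the rewritten `Shared` through the public futures 0.1 API of this branch: `shared()` turns a future into a cloneable handle backed by the new `Arc<Mutex<State<F>>>`, and each clone resolves to the same `Arc`-wrapped value, so it can be waited on from several threads.

```rust
// Illustrative sketch, assuming the futures 0.1 API used in this diff
// (`Future::shared`, blocking `wait`). Not part of the patch itself.
extern crate futures;

use std::thread;

use futures::Future;
use futures::future;

fn main() {
    // Any future becomes a cloneable handle via `shared()`.
    let shared = future::ok::<u32, ()>(6).shared();
    let other = shared.clone();

    // Both handles can be waited on independently; whichever handle takes the
    // inner mutex drives the original future, the others see `State::Done`.
    let t = thread::spawn(move || {
        // The result is stored behind an `Arc`, hence the deref.
        assert_eq!(*other.wait().unwrap(), 6);
    });
    assert_eq!(*shared.wait().unwrap(), 6);
    t.join().unwrap();
}
```

The `drop_in_poll` test above covers the trickier interaction this sketch does not: a `Shared` clone being dropped while the lock is already held by `poll`, which is exactly why `Drop` uses `try_lock` instead of `lock`.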