Skip to content

Commit

Permalink
Safer poll timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanWoollett-Light committed Nov 23, 2022
1 parent 7591f81 commit ed39c68
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ This project adheres to [Semantic Versioning](https://semver.org/).
([#1870](/~https://github.com/nix-rust/nix/pull/1870))
- The `length` argument of `sys::mman::mmap` is now of type `NonZeroUsize`.
([#1873](/~https://github.com/nix-rust/nix/pull/1873))
- The `timeout` argument of `poll::poll` is now of type `poll::PollTimeout`.
([#1876](/~https://github.com/nix-rust/nix/pull/1876))

### Fixed

Expand Down
256 changes: 254 additions & 2 deletions src/poll.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
//! Wait for events to trigger on specific file descriptors
use std::convert::TryFrom;
use std::fmt;
use std::os::unix::io::{AsRawFd, RawFd};
use std::time::Duration;

use crate::errno::Errno;
use crate::Result;
Expand Down Expand Up @@ -112,6 +115,255 @@ libc_bitflags! {
}
}

/// Timeout argument for [`poll`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct PollTimeout(i32);

/// Error type for [`PollTimeout::try_from::<i128>::()`].
#[derive(Debug, Clone, Copy)]
pub enum TryFromI128Error {
/// Value is less than -1.
Underflow(crate::Errno),
/// Value is greater than [`i32::MAX`].
Overflow(<i32 as TryFrom<i128>>::Error),
}
impl fmt::Display for TryFromI128Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Underflow(err) => write!(f, "Underflow: {}", err),
Self::Overflow(err) => write!(f, "Overflow: {}", err),
}
}
}
impl std::error::Error for TryFromI128Error {}

/// Error type for [`PollTimeout::try_from::<i68>()`].
#[derive(Debug, Clone, Copy)]
pub enum TryFromI64Error {
/// Value is less than -1.
Underflow(crate::Errno),
/// Value is greater than [`i32::MAX`].
Overflow(<i32 as TryFrom<i64>>::Error),
}
impl fmt::Display for TryFromI64Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Underflow(err) => write!(f, "Underflow: {}", err),
Self::Overflow(err) => write!(f, "Overflow: {}", err),
}
}
}
impl std::error::Error for TryFromI64Error {}

// These cases implement slightly different conversions that make using generics impossible without
// specialization.
impl PollTimeout {
/// Blocks indefinitely.
pub const NONE: Self = Self(1 << 31);
/// Returns immediately.
pub const ZERO: Self = Self(0);
/// Blocks for at most [`std::i32::MAX`] milliseconds.
pub const MAX: Self = Self(i32::MAX);
/// Returns if `self` equals [`PollTimeout::NONE`].
pub fn is_none(&self) -> bool {
*self == Self::NONE
}
/// Returns if `self` does not equal [`PollTimeout::NONE`].
pub fn is_some(&self) -> bool {
!self.is_none()
}
/// Returns the timeout in milliseconds if there is some, otherwise returns `None`.
pub fn timeout(&self) -> Option<i32> {
self.is_some().then(|| self.0)
}
}
impl TryFrom<Duration> for PollTimeout {
type Error = <i32 as TryFrom<u128>>::Error;
fn try_from(x: Duration) -> std::result::Result<Self, Self::Error> {
Ok(Self(i32::try_from(x.as_millis())?))
}
}
impl TryFrom<u128> for PollTimeout {
type Error = <i32 as TryFrom<u128>>::Error;
fn try_from(x: u128) -> std::result::Result<Self, Self::Error> {
Ok(Self(i32::try_from(x)?))
}
}
impl TryFrom<u64> for PollTimeout {
type Error = <i32 as TryFrom<u64>>::Error;
fn try_from(x: u64) -> std::result::Result<Self, Self::Error> {
Ok(Self(i32::try_from(x)?))
}
}
impl TryFrom<u32> for PollTimeout {
type Error = <i32 as TryFrom<u32>>::Error;
fn try_from(x: u32) -> std::result::Result<Self, Self::Error> {
Ok(Self(i32::try_from(x)?))
}
}
impl From<u16> for PollTimeout {
fn from(x: u16) -> Self {
Self(i32::from(x))
}
}
impl From<u8> for PollTimeout {
fn from(x: u8) -> Self {
Self(i32::from(x))
}
}
impl TryFrom<i128> for PollTimeout {
type Error = TryFromI128Error;
fn try_from(x: i128) -> std::result::Result<Self, Self::Error> {
match x {
-1 => Ok(Self::NONE),
millis @ 0.. => Ok(Self(
i32::try_from(millis).map_err(TryFromI128Error::Overflow)?,
)),
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
_ => Err(TryFromI128Error::Underflow(Errno::EINVAL)),
}
}
}
impl TryFrom<i64> for PollTimeout {
type Error = TryFromI64Error;
fn try_from(x: i64) -> std::result::Result<Self, Self::Error> {
match x {
-1 => Ok(Self::NONE),
millis @ 0.. => Ok(Self(
i32::try_from(millis).map_err(TryFromI64Error::Overflow)?,
)),
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
_ => Err(TryFromI64Error::Underflow(Errno::EINVAL)),
}
}
}
impl TryFrom<i32> for PollTimeout {
type Error = Errno;
fn try_from(x: i32) -> Result<Self> {
match x {
-1 => Ok(Self::NONE),
millis @ 0.. => Ok(Self(millis)),
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
_ => Err(Errno::EINVAL),
}
}
}
impl TryFrom<i16> for PollTimeout {
type Error = Errno;
fn try_from(x: i16) -> Result<Self> {
match x {
-1 => Ok(Self::NONE),
millis @ 0.. => Ok(Self(millis.into())),
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
_ => Err(Errno::EINVAL),
}
}
}
impl TryFrom<i8> for PollTimeout {
type Error = Errno;
fn try_from(x: i8) -> Result<Self> {
match x {
-1 => Ok(Self::NONE),
millis @ 0.. => Ok(Self(millis.into())),
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
_ => Err(Errno::EINVAL),
}
}
}
impl TryFrom<PollTimeout> for Duration {
type Error = ();
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
match x.timeout() {
// SAFETY: `x.0` is always positive.
Some(millis) => Ok(Duration::from_millis(unsafe {
u64::try_from(millis).unwrap_unchecked()
})),
None => Err(()),
}
}
}
impl TryFrom<PollTimeout> for u128 {
type Error = ();
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
match x.timeout() {
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
Some(millis) => {
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
}
None => Err(()),
}
}
}
impl TryFrom<PollTimeout> for u64 {
type Error = ();
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
match x.timeout() {
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
Some(millis) => {
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
}
None => Err(()),
}
}
}
impl TryFrom<PollTimeout> for u32 {
type Error = ();
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
match x.timeout() {
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
Some(millis) => {
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
}
None => Err(()),
}
}
}
impl TryFrom<PollTimeout> for u16 {
type Error = Option<<Self as TryFrom<i32>>::Error>;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
match x.timeout() {
Some(millis) => Ok(Self::try_from(millis).map_err(Some)?),
None => Err(None),
}
}
}
impl TryFrom<PollTimeout> for u8 {
type Error = Option<<Self as TryFrom<i32>>::Error>;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
match x.timeout() {
Some(millis) => Ok(Self::try_from(millis).map_err(Some)?),
None => Err(None),
}
}
}
impl From<PollTimeout> for i128 {
fn from(x: PollTimeout) -> Self {
x.timeout().unwrap_or(-1).into()
}
}
impl From<PollTimeout> for i64 {
fn from(x: PollTimeout) -> Self {
x.timeout().unwrap_or(-1).into()
}
}
impl From<PollTimeout> for i32 {
fn from(x: PollTimeout) -> Self {
x.timeout().unwrap_or(-1)
}
}
impl TryFrom<PollTimeout> for i16 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.timeout().unwrap_or(-1))
}
}
impl TryFrom<PollTimeout> for i8 {
type Error = <Self as TryFrom<i32>>::Error;
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
Self::try_from(x.timeout().unwrap_or(-1))
}
}

/// `poll` waits for one of a set of file descriptors to become ready to perform I/O.
/// ([`poll(2)`](https://pubs.opengroup.org/onlinepubs/9699919799/functions/poll.html))
///
Expand All @@ -132,12 +384,12 @@ libc_bitflags! {
/// in timeout means an infinite timeout. Specifying a timeout of zero
/// causes `poll()` to return immediately, even if no file descriptors are
/// ready.
pub fn poll(fds: &mut [PollFd], timeout: libc::c_int) -> Result<libc::c_int> {
pub fn poll(fds: &mut [PollFd], timeout: PollTimeout) -> Result<libc::c_int> {
let res = unsafe {
libc::poll(
fds.as_mut_ptr() as *mut libc::pollfd,
fds.len() as libc::nfds_t,
timeout,
timeout.into(),
)
};

Expand Down
6 changes: 3 additions & 3 deletions test/test_poll.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use nix::{
errno::Errno,
poll::{poll, PollFd, PollFlags},
poll::{poll, PollFd, PollFlags, PollTimeout},
unistd::{pipe, write},
};

Expand All @@ -22,14 +22,14 @@ fn test_poll() {
let mut fds = [PollFd::new(r, PollFlags::POLLIN)];

// Poll an idle pipe. Should timeout
let nfds = loop_while_eintr!(poll(&mut fds, 100));
let nfds = loop_while_eintr!(poll(&mut fds, PollTimeout::from(100u8)));
assert_eq!(nfds, 0);
assert!(!fds[0].revents().unwrap().contains(PollFlags::POLLIN));

write(w, b".").unwrap();

// Poll a readable pipe. Should return an event.
let nfds = poll(&mut fds, 100).unwrap();
let nfds = poll(&mut fds, PollTimeout::from(100u8)).unwrap();
assert_eq!(nfds, 1);
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
}
Expand Down

0 comments on commit ed39c68

Please sign in to comment.