From 59fc7603e2eb2ac3e500927d98514f431d691481 Mon Sep 17 00:00:00 2001 From: wcampbell Date: Mon, 21 Oct 2024 19:06:31 -0400 Subject: [PATCH] Add count Vec Specializations (#481) * For Vec when using count, specialize into reading the bytes all at once See #462 --- benches/deku.rs | 42 +++++++++++++++-- deku-derive/src/lib.rs | 5 +- deku-derive/src/macros/deku_read.rs | 46 +++++++++++++++---- ensure_no_std/src/bin/main.rs | 7 ++- src/attributes.rs | 3 +- src/ctx.rs | 3 ++ src/impls/vec.rs | 14 ++++++ .../test_attributes/test_limits/test_count.rs | 26 ++++++++++- 8 files changed, 126 insertions(+), 20 deletions(-) diff --git a/benches/deku.rs b/benches/deku.rs index 79514336..80dd55a2 100644 --- a/benches/deku.rs +++ b/benches/deku.rs @@ -111,7 +111,7 @@ fn criterion_benchmark(c: &mut Criterion) { }); } -pub fn read_all_vs_count(c: &mut Criterion) { +pub fn read_all_vs_count_vs_read_exact(c: &mut Criterion) { #[derive(DekuRead, DekuWrite)] pub struct AllWrapper { #[deku(read_all)] @@ -119,8 +119,20 @@ pub fn read_all_vs_count(c: &mut Criterion) { } #[derive(DekuRead, DekuWrite)] - #[deku(ctx = "len: usize")] pub struct CountWrapper { + #[deku(count = "1500")] + pub data: Vec, + } + + #[derive(DekuRead, DekuWrite)] + pub struct CountNonSpecialize { + #[deku(count = "(1500/2)")] + pub data: Vec, + } + + #[derive(DekuRead, DekuWrite)] + #[deku(ctx = "len: usize")] + pub struct CountFromCtxWrapper { #[deku(count = "len")] pub data: Vec, } @@ -137,14 +149,34 @@ pub fn read_all_vs_count(c: &mut Criterion) { }) }); - c.bench_function("count", |b| { + c.bench_function("count_specialize", |b| { + b.iter(|| { + let mut cursor = Cursor::new([1u8; 1500].as_ref()); + let mut reader = Reader::new(&mut cursor); + CountWrapper::from_reader_with_ctx(black_box(&mut reader), ()) + }) + }); + + c.bench_function("count_from_u8_specialize", |b| { + b.iter(|| { + let mut cursor = Cursor::new([1u8; 1500].as_ref()); + let mut reader = Reader::new(&mut cursor); + CountWrapper::from_reader_with_ctx(black_box(&mut reader), ()) + }) + }); + + c.bench_function("count_no_specialize", |b| { b.iter(|| { let mut cursor = Cursor::new([1u8; 1500].as_ref()); let mut reader = Reader::new(&mut cursor); - CountWrapper::from_reader_with_ctx(black_box(&mut reader), 1500) + CountNonSpecialize::from_reader_with_ctx(black_box(&mut reader), ()) }) }); } -criterion_group!(benches, criterion_benchmark, read_all_vs_count); +criterion_group!( + benches, + criterion_benchmark, + read_all_vs_count_vs_read_exact +); criterion_main!(benches); diff --git a/deku-derive/src/lib.rs b/deku-derive/src/lib.rs index 20e891d9..3d1a4251 100644 --- a/deku-derive/src/lib.rs +++ b/deku-derive/src/lib.rs @@ -8,6 +8,7 @@ extern crate alloc; use alloc::borrow::Cow; use std::convert::TryFrom; use std::fmt::Display; +use syn::Type; use darling::{ast, FromDeriveInput, FromField, FromMeta, FromVariant, ToTokens}; use proc_macro2::TokenStream; @@ -412,7 +413,7 @@ impl<'a> TryFrom<&'a DekuData> for DekuDataStruct<'a> { #[derive(Debug)] struct FieldData { ident: Option, - ty: syn::Type, + ty: Type, /// endianness for the field endian: Option, @@ -859,7 +860,7 @@ fn default_res_opt() -> Result, E> { #[darling(attributes(deku))] struct DekuFieldReceiver { ident: Option, - ty: syn::Type, + ty: Type, /// Endianness for the field #[darling(default)] diff --git a/deku-derive/src/macros/deku_read.rs b/deku-derive/src/macros/deku_read.rs index 5208633a..4331ad49 100644 --- a/deku-derive/src/macros/deku_read.rs +++ b/deku-derive/src/macros/deku_read.rs @@ -777,14 +777,44 @@ fn emit_field_read( } } } else if let Some(field_count) = &f.count { - quote! { - { - use core::borrow::Borrow; - #type_as_deku_read::from_reader_with_ctx - ( - __deku_reader, - (::#crate_::ctx::Limit::new_count(usize::try_from(*((#field_count).borrow()))?), (#read_args)) - )? + use syn::{GenericArgument, PathArguments, Type}; + let mut is_vec_u8 = false; + if let Type::Path(type_path) = &f.ty { + if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Vec" { + if let PathArguments::AngleBracketed(ref generic_args) = + type_path.path.segments[0].arguments + { + if generic_args.args.len() == 1 { + if let GenericArgument::Type(Type::Path(ref arg_path)) = + generic_args.args[0] + { + is_vec_u8 = arg_path.path.is_ident("u8"); + } + } + } + } + } + if is_vec_u8 { + quote! { + { + use core::borrow::Borrow; + #type_as_deku_read::from_reader_with_ctx + ( + __deku_reader, + ::#crate_::ctx::ReadExact(usize::try_from(*((#field_count).borrow()))?) + )? + } + } + } else { + quote! { + { + use core::borrow::Borrow; + #type_as_deku_read::from_reader_with_ctx + ( + __deku_reader, + (::#crate_::ctx::Limit::new_count(usize::try_from(*((#field_count).borrow()))?), (#read_args)) + )? + } } } } else if let Some(field_bytes) = &f.bytes_read { diff --git a/ensure_no_std/src/bin/main.rs b/ensure_no_std/src/bin/main.rs index 3c8ac7f6..67729557 100644 --- a/ensure_no_std/src/bin/main.rs +++ b/ensure_no_std/src/bin/main.rs @@ -24,6 +24,8 @@ struct DekuTest { count: u8, #[deku(count = "count", pad_bytes_after = "8")] data: Vec, + #[deku(count = "1")] + after: Vec, } #[entry] @@ -39,7 +41,7 @@ fn main() -> ! { // now the allocator is ready types like Box, Vec can be used. #[allow(clippy::unusual_byte_groupings)] - let test_data: &[u8] = &[0b10101_101, 0x02, 0xBE, 0xEF, 0xff]; + let test_data: &[u8] = &[0b10101_101, 0x02, 0xBE, 0xEF, 0xff, 0xaa]; let mut cursor = deku::no_std_io::Cursor::new(test_data); // Test reading @@ -49,7 +51,8 @@ fn main() -> ! { field_a: 0b10101, field_b: 0b101, count: 0x02, - data: vec![0xBE, 0xEF] + data: vec![0xBE, 0xEF], + after: vec![0xaa], }, val ); diff --git a/src/attributes.rs b/src/attributes.rs index 5ed86166..e9b6146b 100644 --- a/src/attributes.rs +++ b/src/attributes.rs @@ -649,6 +649,8 @@ assert_eq!(data, value); **Note**: See [update](#update) for more information on the attribute! +## Specializations +- `Vec`: `count` used with a byte vector will result in one invocation to `read_bytes`, thus improving performance. # bytes_read @@ -769,7 +771,6 @@ let value: Vec = value.try_into().unwrap(); assert_eq!(&*data, value); ``` - # update Specify custom code to run on the field when `.update()` is called on the struct/enum diff --git a/src/ctx.rs b/src/ctx.rs index 5565be21..393a1a99 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -201,3 +201,6 @@ impl BitSize { Self::bits_from_reader(core::mem::size_of_val(val)) } } + +/// Amount of bytes to read_exact +pub struct ReadExact(pub usize); diff --git a/src/impls/vec.rs b/src/impls/vec.rs index 33c89c99..fa93bcb0 100644 --- a/src/impls/vec.rs +++ b/src/impls/vec.rs @@ -8,6 +8,20 @@ use crate::writer::Writer; use crate::{ctx::*, DekuReader}; use crate::{DekuError, DekuWriter}; +impl<'a> DekuReader<'a, ReadExact> for Vec { + fn from_reader_with_ctx( + reader: &mut Reader, + exact: ReadExact, + ) -> Result + where + Self: Sized, + { + let mut bytes = alloc::vec![0x00; exact.0]; + let _ = reader.read_bytes(exact.0, &mut bytes)?; + Ok(bytes) + } +} + /// Read `T`s into a vec until a given predicate returns true /// * `capacity` - an optional capacity to pre-allocate the vector with /// * `ctx` - The context required by `T`. It will be passed to every `T` when constructing. diff --git a/tests/test_attributes/test_limits/test_count.rs b/tests/test_attributes/test_limits/test_count.rs index c2e78712..544711e7 100644 --- a/tests/test_attributes/test_limits/test_count.rs +++ b/tests/test_attributes/test_limits/test_count.rs @@ -27,6 +27,28 @@ mod test_slice { assert_eq!(test_data, ret_write); } + #[test] + fn test_count_static_non_u8() { + #[derive(PartialEq, Debug, DekuRead, DekuWrite)] + struct TestStruct { + #[deku(count = "1")] + data: Vec<(u8, u8)>, + } + + let test_data: Vec = [0xaa, 0xbb].to_vec(); + + let ret_read = TestStruct::try_from(test_data.as_slice()).unwrap(); + assert_eq!( + TestStruct { + data: vec![(0xaa, 0xbb)], + }, + ret_read + ); + + let ret_write: Vec = ret_read.try_into().unwrap(); + assert_eq!(test_data, ret_write); + } + #[test] fn test_count_from_field() { #[derive(PartialEq, Debug, DekuRead, DekuWrite)] @@ -74,7 +96,7 @@ mod test_slice { } #[test] - #[should_panic(expected = "Incomplete(NeedSize { bits: 8 })")] + #[should_panic(expected = "Incomplete(NeedSize { bits: 24 })")] fn test_count_error() { #[derive(PartialEq, Debug, DekuRead, DekuWrite)] struct TestStruct { @@ -156,7 +178,7 @@ mod test_vec { } #[test] - #[should_panic(expected = "Incomplete(NeedSize { bits: 8 })")] + #[should_panic(expected = "Incomplete(NeedSize { bits: 24 })")] fn test_count_error() { #[derive(PartialEq, Debug, DekuRead, DekuWrite)] struct TestStruct {