Skip to content

Commit

Permalink
Refactor DekuRead to return bits read
Browse files Browse the repository at this point in the history
- Instead of always returning a new performance costly BitSlice, just
  return the amount of bits read and just increment the starting len
  into the BitSlice in DekuRead and DekuContainerRead functions.
  • Loading branch information
wcampbell0x2a committed Jun 15, 2023
1 parent 47b2f7e commit f3bc876
Show file tree
Hide file tree
Showing 57 changed files with 705 additions and 647 deletions.
6 changes: 3 additions & 3 deletions benches/deku.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ fn criterion_benchmark(c: &mut Criterion) {
});

let deku_read_vec_input = {
let mut v = [0xFFu8; 101].to_vec();
let mut v = [0xffu8; 101].to_vec();
v[0] = 100u8;
v
};
let deku_write_vec_input = DekuVec {
count: 100,
data: vec![0xFF; 100],
data: vec![0xff; 100],
};
c.bench_function("deku_read_vec", |b| {
b.iter(|| deku_read_vec(black_box(&deku_read_vec_input)))
Expand All @@ -122,7 +122,7 @@ fn criterion_benchmark(c: &mut Criterion) {

let deku_write_vec_input = DekuVecPerf {
count: 100,
data: vec![0xFF; 100],
data: vec![0xff; 100],
};
c.bench_function("deku_read_vec_perf", |b| {
b.iter(|| deku_read_vec_perf(black_box(&deku_read_vec_input)))
Expand Down
36 changes: 21 additions & 15 deletions deku-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@ Procedural macros that implement `DekuRead` and `DekuWrite` traits

#![warn(missing_docs)]

use crate::macros::{deku_read::emit_deku_read, deku_write::emit_deku_write};
use std::borrow::Cow;
use std::convert::TryFrom;

use darling::{ast, FromDeriveInput, FromField, FromMeta, FromVariant, ToTokens};
use proc_macro2::TokenStream;
use quote::quote;
use std::borrow::Cow;
use std::convert::TryFrom;
use syn::{punctuated::Punctuated, spanned::Spanned, AttributeArgs};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::AttributeArgs;

use crate::macros::deku_read::emit_deku_read;
use crate::macros::deku_write::emit_deku_write;

mod macros;

Expand Down Expand Up @@ -210,7 +215,7 @@ impl DekuData {
} else {
Ok(())
}
}
},
ast::Data::Enum(_) => {
// Validate `type` or `id` is specified
if data.id_type.is_none() && data.id.is_none() {
Expand Down Expand Up @@ -251,7 +256,7 @@ impl DekuData {
}

Ok(())
}
},
}
}

Expand Down Expand Up @@ -692,7 +697,7 @@ fn map_litstr_as_tokenstream(
v.parse::<TokenStream>()
.expect("could not parse token stream"),
)
}
},
None => None,
})
}
Expand All @@ -709,7 +714,7 @@ fn gen_field_ident<T: ToString>(ident: Option<T>, index: usize, prefix: bool) ->
let index = syn::Index::from(index);
let prefix = if prefix { "field_" } else { "" };
format!("{}{}", prefix, quote! { #index })
}
},
};

field_name.parse().unwrap()
Expand Down Expand Up @@ -907,15 +912,15 @@ fn remove_deku_attrs(fields: &mut syn::Fields) {
match fields {
syn::Fields::Named(ref mut fields) => remove_deku_field_attrs(&mut fields.named),
syn::Fields::Unnamed(ref mut fields) => remove_deku_field_attrs(&mut fields.unnamed),
syn::Fields::Unit => {}
syn::Fields::Unit => {},
}
}

fn remove_temp_fields(fields: &mut syn::Fields) {
match fields {
syn::Fields::Named(ref mut fields) => remove_deku_temp_fields(&mut fields.named),
syn::Fields::Unnamed(ref mut fields) => remove_deku_temp_fields(&mut fields.unnamed),
syn::Fields::Unit => {}
syn::Fields::Unit => {},
}
}

Expand All @@ -941,7 +946,7 @@ pub fn deku_derive(
Ok(v) => v,
Err(e) => {
return proc_macro::TokenStream::from(e.write_errors());
}
},
};

// Parse item
Expand All @@ -966,7 +971,7 @@ pub fn deku_derive(
for variant in input_enum.variants.iter_mut() {
remove_temp_fields(&mut variant.fields)
}
}
},
_ => unimplemented!(),
}

Expand All @@ -982,13 +987,13 @@ pub fn deku_derive(
syn::Data::Struct(ref mut input_struct) => {
input.attrs.retain(is_not_deku);
remove_deku_attrs(&mut input_struct.fields)
}
},
syn::Data::Enum(ref mut input_enum) => {
for variant in input_enum.variants.iter_mut() {
variant.attrs.retain(is_not_deku);
remove_deku_attrs(&mut variant.fields)
}
}
},
_ => unimplemented!(),
}

Expand All @@ -1006,10 +1011,11 @@ pub fn deku_derive(

#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
use syn::parse_str;

use super::*;

#[rstest(input,
// Valid struct
case::struct_empty(r#"struct Test {}"#),
Expand Down
82 changes: 40 additions & 42 deletions deku-derive/src/macros/deku_read.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use std::convert::TryFrom;

use darling::ast::{Data, Fields};
use darling::ToTokens;
use proc_macro2::TokenStream;
use quote::quote;
use syn::spanned::Spanned;

use crate::macros::{
gen_ctx_types_and_arg, gen_field_args, gen_internal_field_ident, gen_internal_field_idents,
gen_type_from_ctx_id, pad_bits, token_contains_string, wrap_default_ctx,
};
use crate::{DekuData, DekuDataEnum, DekuDataStruct, FieldData, Id};
use darling::{
ast::{Data, Fields},
ToTokens,
};
use proc_macro2::TokenStream;
use quote::quote;
use std::convert::TryFrom;
use syn::spanned::Spanned;

pub(crate) fn emit_deku_read(input: &DekuData) -> Result<TokenStream, syn::Error> {
match &input.data {
Expand Down Expand Up @@ -64,19 +64,15 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {
use core::convert::TryFrom;
use ::#crate_::bitvec::BitView;
let __deku_input_bits = __deku_input.0.view_bits::<::#crate_::bitvec::Msb0>();

let mut __deku_rest = __deku_input_bits;
__deku_rest = &__deku_rest[__deku_input.1..];
let mut __deku_rest = &__deku_input_bits[__deku_input.1..];
let mut __deku_total_read = 0;

#magic_read

#(#field_reads)*
let __deku_value = #initialize_struct;

let __deku_pad = 8 * ((__deku_rest.len() + 7) / 8) - __deku_rest.len();
let __deku_read_idx = __deku_input_bits.len() - (__deku_rest.len() + __deku_pad);

Ok(((__deku_input_bits[__deku_read_idx..].domain().region().unwrap().1, __deku_pad), __deku_value))
Ok((__deku_total_read, __deku_value))
},
&input.ctx,
&input.ctx_default,
Expand All @@ -98,18 +94,19 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {
let read_body = quote! {
use core::convert::TryFrom;
let mut __deku_rest = __deku_input_bits;
let mut __deku_total_read = 0;

#magic_read

#(#field_reads)*
let __deku_value = #initialize_struct;

Ok((__deku_rest, __deku_value))
Ok((__deku_total_read, __deku_value))
};

tokens.extend(quote! {
impl #imp ::#crate_::DekuRead<#lifetime, #ctx_types> for #ident #wher {
fn read(__deku_input_bits: &#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, #ctx_arg) -> core::result::Result<(&#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, Self), ::#crate_::DekuError> {
fn read(__deku_input_bits: &#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, #ctx_arg) -> core::result::Result<(usize, Self), ::#crate_::DekuError> {
#read_body
}
}
Expand All @@ -120,7 +117,7 @@ fn emit_struct(input: &DekuData) -> Result<TokenStream, syn::Error> {

tokens.extend(quote! {
impl #imp ::#crate_::DekuRead<#lifetime> for #ident #wher {
fn read(__deku_input_bits: &#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, _: ()) -> core::result::Result<(&#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, Self), ::#crate_::DekuError> {
fn read(__deku_input_bits: &#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, _: ()) -> core::result::Result<(usize, Self), ::#crate_::DekuError> {
#read_body
}
}
Expand Down Expand Up @@ -229,7 +226,8 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
// if we're consuming an id, set the rest to new_rest before reading the variant
let new_rest = if consume_id {
quote! {
__deku_rest = __deku_new_rest;
__deku_rest = &__deku_rest[__deku_amt_read..];
__deku_total_read += __deku_amt_read;
}
} else {
quote! {}
Expand Down Expand Up @@ -289,11 +287,11 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {

let variant_id_read = if id.is_some() {
quote! {
let (__deku_new_rest, __deku_variant_id) = (__deku_rest, (#id));
let (__deku_amt_read, __deku_variant_id) = (0, (#id));
}
} else if id_type.is_some() {
quote! {
let (__deku_new_rest, __deku_variant_id) = <#id_type>::read(__deku_rest, (#id_args))?;
let (__deku_amt_read, __deku_variant_id) = <#id_type>::read(__deku_rest, (#id_args))?;
}
} else {
// either `id` or `type` needs to be specified
Expand All @@ -317,18 +315,14 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
use core::convert::TryFrom;
use ::#crate_::bitvec::BitView;
let __deku_input_bits = __deku_input.0.view_bits::<::#crate_::bitvec::Msb0>();

let mut __deku_rest = __deku_input_bits;
__deku_rest = &__deku_rest[__deku_input.1..];
let mut __deku_rest = &__deku_input_bits[__deku_input.1..];
let mut __deku_total_read = 0;

#magic_read

#variant_read

let __deku_pad = 8 * ((__deku_rest.len() + 7) / 8) - __deku_rest.len();
let __deku_read_idx = __deku_input_bits.len() - (__deku_rest.len() + __deku_pad);

Ok(((__deku_input_bits[__deku_read_idx..].domain().region().unwrap().1, __deku_pad), __deku_value))
Ok((__deku_total_read, __deku_value))
},
&input.ctx,
&input.ctx_default,
Expand All @@ -349,18 +343,19 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
let read_body = quote! {
use core::convert::TryFrom;
let mut __deku_rest = __deku_input_bits;
let mut __deku_total_read = 0;

#magic_read

#variant_read

Ok((__deku_rest, __deku_value))
Ok((__deku_total_read, __deku_value))
};

tokens.extend(quote! {
#[allow(non_snake_case)]
impl #imp ::#crate_::DekuRead<#lifetime, #ctx_types> for #ident #wher {
fn read(__deku_input_bits: &#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, #ctx_arg) -> core::result::Result<(&#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, Self), ::#crate_::DekuError> {
fn read(__deku_input_bits: &#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, #ctx_arg) -> core::result::Result<(usize, Self), ::#crate_::DekuError> {
#read_body
}
}
Expand All @@ -372,7 +367,7 @@ fn emit_enum(input: &DekuData) -> Result<TokenStream, syn::Error> {
tokens.extend(quote! {
#[allow(non_snake_case)]
impl #imp ::#crate_::DekuRead<#lifetime> for #ident #wher {
fn read(__deku_input_bits: &#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, _: ()) -> core::result::Result<(&#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, Self), ::#crate_::DekuError> {
fn read(__deku_input_bits: &#lifetime ::#crate_::bitvec::BitSlice<u8, ::#crate_::bitvec::Msb0>, _: ()) -> core::result::Result<(usize, Self), ::#crate_::DekuError> {
#read_body
}
}
Expand Down Expand Up @@ -414,12 +409,13 @@ fn emit_magic_read(input: &DekuData) -> TokenStream {
let __deku_magic = #magic;

for __deku_byte in __deku_magic {
let (__deku_new_rest, __deku_read_byte) = u8::read(__deku_rest, ())?;
let (__deku_amt_read, __deku_read_byte) = u8::read(__deku_rest, ())?;
if *__deku_byte != __deku_read_byte {
return Err(::#crate_::DekuError::Parse(format!("Missing magic value {:?}", #magic)));
}

__deku_rest = __deku_new_rest;
__deku_rest = &__deku_rest[__deku_amt_read..];
__deku_total_read += __deku_amt_read;
}
}
} else {
Expand Down Expand Up @@ -497,6 +493,7 @@ fn emit_padding(bit_size: &TokenStream) -> TokenStream {
if __deku_rest.len() >= __deku_pad {
let (__deku_padded_bits, __deku_new_rest) = __deku_rest.split_at(__deku_pad);
__deku_rest = __deku_new_rest;
__deku_total_read += __deku_pad;
} else {
return Err(::#crate_::DekuError::Incomplete(::#crate_::error::NeedSize::new(__deku_pad)));
}
Expand Down Expand Up @@ -655,10 +652,11 @@ fn emit_field_read(
);

let field_read_normal = quote! {
let (__deku_new_rest, __deku_value) = #field_read_func?;
let (__deku_amt_read, __deku_value) = #field_read_func?;
let __deku_value: #field_type = #field_map(__deku_value)?;

__deku_rest = __deku_new_rest;
__deku_rest = &__deku_rest[__deku_amt_read..];
__deku_total_read += __deku_amt_read;

__deku_value
};
Expand All @@ -675,13 +673,13 @@ fn emit_field_read(
#field_read_normal
}
}
}
},
(true, None) => {
// #[deku(skip)] ==> `skip`
quote! {
#field_default
}
}
},
(false, Some(field_cond)) => {
// #[deku(cond = "...")] ==> read if `cond`
quote! {
Expand All @@ -691,12 +689,12 @@ fn emit_field_read(
#field_default
}
}
}
},
(false, None) => {
quote! {
#field_read_normal
}
}
},
};

let field_read = quote! {
Expand Down Expand Up @@ -732,7 +730,7 @@ pub fn emit_from_bytes(
quote! {
impl #imp ::#crate_::DekuContainerRead<#lifetime> for #ident #wher {
#[allow(non_snake_case)]
fn from_bytes(__deku_input: (&#lifetime [u8], usize)) -> core::result::Result<((&#lifetime [u8], usize), Self), ::#crate_::DekuError> {
fn from_bytes(__deku_input: (&#lifetime [u8], usize)) -> core::result::Result<(usize, Self), ::#crate_::DekuError> {
#body
}
}
Expand All @@ -752,8 +750,8 @@ pub fn emit_try_from(
type Error = ::#crate_::DekuError;

fn try_from(input: &#lifetime [u8]) -> core::result::Result<Self, Self::Error> {
let (rest, res) = <Self as ::#crate_::DekuContainerRead>::from_bytes((input, 0))?;
if !rest.0.is_empty() {
let (amt_read, res) = <Self as ::#crate_::DekuContainerRead>::from_bytes((input, 0))?;
if (amt_read / 8) != input.len() {
return Err(::#crate_::DekuError::Parse(format!("Too much data")));
}
Ok(res)
Expand Down
Loading

0 comments on commit f3bc876

Please sign in to comment.