diff --git a/deku-derive/src/lib.rs b/deku-derive/src/lib.rs index f760fac6..9104737f 100644 --- a/deku-derive/src/lib.rs +++ b/deku-derive/src/lib.rs @@ -341,6 +341,12 @@ struct FieldData { /// condition to parse field cond: Option, + + // assertion on field + assert: Option, + + // assert value of field + assert_eq: Option, } impl FieldData { @@ -374,6 +380,8 @@ impl FieldData { temp: receiver.temp, default: receiver.default?, cond: receiver.cond?, + assert: receiver.assert?, + assert_eq: receiver.assert_eq?, }; FieldData::validate(&data)?; @@ -716,6 +724,14 @@ struct DekuFieldReceiver { /// condition to parse field #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")] cond: Result, ReplacementError>, + + // assertion on field + #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")] + assert: Result, ReplacementError>, + + // assert value of field + #[darling(default = "default_res_opt", map = "map_litstr_as_tokenstream")] + assert_eq: Result, ReplacementError>, } /// Receiver for the variant-level attributes inside a enum diff --git a/deku-derive/src/macros/deku_read.rs b/deku-derive/src/macros/deku_read.rs index 48a6fb0c..7450de44 100644 --- a/deku-derive/src/macros/deku_read.rs +++ b/deku-derive/src/macros/deku_read.rs @@ -467,6 +467,7 @@ fn emit_field_read( let field_reader = &f.reader; + // fields to check usage of bit/byte offset let field_check_vars = [ &f.count, &f.bits_read, @@ -477,6 +478,8 @@ fn emit_field_read( &f.map, &f.reader, &f.ctx.as_ref().map(|v| quote!(#v)), + &f.assert, + &f.assert_eq, ]; let (bit_offset, byte_offset) = emit_bit_byte_offsets(&field_check_vars); @@ -490,9 +493,39 @@ fn emit_field_read( .or_else(|| Some(quote! { Result::<_, DekuError>::Ok })); let field_ident = f.get_ident(i, true); - + let field_ident_str = field_ident.to_string(); let internal_field_ident = gen_internal_field_ident(&field_ident); + let field_assert = f.assert.as_ref().map(|v| { + quote! { + if (!(#v)) { + // assertion is false, raise error + return Err(DekuError::Assertion(format!( + "field '{}' failed assertion: {}", + #field_ident_str, + stringify!(#v) + ))); + } else { + // do nothing + } + } + }); + + let field_assert_eq = f.assert_eq.as_ref().map(|v| { + quote! { + if (!(#internal_field_ident == (#v))) { + // assertion is false, raise error + return Err(DekuError::Assertion(format!( + "field '{}' failed assertion: {}", + #field_ident_str, + stringify!(#field_ident == #v) + ))); + } else { + // do nothing + } + } + }); + let field_read_func = if field_reader.is_some() { quote! { #field_reader } } else { @@ -610,6 +643,9 @@ fn emit_field_read( }; let #field_ident = &#internal_field_ident; + #field_assert + #field_assert_eq + #pad_bits_after }; diff --git a/deku-derive/src/macros/deku_write.rs b/deku-derive/src/macros/deku_write.rs index 3763ebc9..fe86a65d 100644 --- a/deku-derive/src/macros/deku_write.rs +++ b/deku-derive/src/macros/deku_write.rs @@ -483,12 +483,50 @@ fn emit_field_write( ) -> Result { let field_endian = f.endian.as_ref().or_else(|| input.endian.as_ref()); - let field_check_vars = [&f.writer, &f.cond, &f.ctx.as_ref().map(|v| quote!(#v))]; + // fields to check usage of bit/byte offset + let field_check_vars = [ + &f.writer, + &f.cond, + &f.ctx.as_ref().map(|v| quote!(#v)), + &f.assert, + &f.assert_eq, + ]; let (bit_offset, byte_offset) = emit_bit_byte_offsets(&field_check_vars); let field_writer = &f.writer; let field_ident = f.get_ident(i, object_prefix.is_none()); + let field_ident_str = field_ident.to_string(); + + let field_assert = f.assert.as_ref().map(|v| { + quote! { + if (!(#v)) { + // assertion is false, raise error + return Err(DekuError::Assertion(format!( + "field '{}' failed assertion: {}", + #field_ident_str, + stringify!(#v) + ))); + } else { + // do nothing + } + } + }); + + let field_assert_eq = f.assert_eq.as_ref().map(|v| { + quote! { + if (!(*(#field_ident) == (#v))) { + // assertion is false, raise error + return Err(DekuError::Assertion(format!( + "field '{}' failed assertion: {}", + #field_ident_str, + stringify!(#field_ident == #v) + ))); + } else { + // do nothing + } + } + }); let field_write_func = if field_writer.is_some() { quote! { #field_writer } @@ -551,6 +589,9 @@ fn emit_field_write( #bit_offset #byte_offset + #field_assert + #field_assert_eq + #field_write_tokens #pad_bits_after diff --git a/src/attributes.rs b/src/attributes.rs index 9f8a9b32..8aff3446 100644 --- a/src/attributes.rs +++ b/src/attributes.rs @@ -7,6 +7,8 @@ A documentation-only module for #\[deku\] attributes |-----------|------------------|------------ | [endian](#endian) | top-level, field | Set the endianness | [magic](#magic) | top-level | A magic value that must be present at the start of this struct/enum +| [assert](#assert) | field | Assert a condition +| [assert_eq](#assert_eq) | field | Assert equals on the field | [bits](#bits) | field | Set the bit-size of the field | [bytes](#bytes) | field | Set the byte-size of the field | [count](#count) | field | Set the field representing the element count of a container @@ -139,6 +141,63 @@ let value: Vec = value.try_into().unwrap(); assert_eq!(data, value); ``` +# assert + +Assert a condition after reading and before writing a field + +Example: +```rust +# use deku::prelude::*; +# use std::convert::{TryInto, TryFrom}; +# #[derive(Debug, PartialEq, DekuRead, DekuWrite)] +struct DekuTest { + #[deku(assert = "*data >= 8")] + data: u8 +} + +let data: Vec = vec![0x00, 0x01, 0x02]; + +let value = DekuTest::try_from(data.as_ref()); + +assert_eq!( + Err(DekuError::Assertion("field 'data' failed assertion: * data >= 8".into())), + value +); +``` + +# assert_eq + +Assert equals after reading and before writing a field + +Example: +```rust +# use deku::prelude::*; +# use std::convert::{TryInto, TryFrom}; +# #[derive(Debug, PartialEq, DekuRead, DekuWrite)] +struct DekuTest { + #[deku(assert_eq = "0x01")] + data: u8, +} + +let data: Vec = vec![0x01]; + +let mut value = DekuTest::try_from(data.as_ref()).unwrap(); + +assert_eq!( + DekuTest { data: 0x01 }, + value +); + +value.data = 0x02; + +let value: Result, DekuError> = value.try_into(); + +assert_eq!( + Err(DekuError::Assertion("field 'data' failed assertion: data == 0x01".into())), + value +); +``` + # bits Set the bit-size of the field diff --git a/src/error.rs b/src/error.rs index 6964b829..0e8d77ed 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,6 +13,8 @@ pub enum DekuError { InvalidParam(String), /// Unexpected error Unexpected(String), + /// Assertion error from `assert` or `assert_eq` attributes + Assertion(String), } impl From for DekuError { @@ -39,6 +41,7 @@ impl core::fmt::Display for DekuError { DekuError::Parse(ref err) => write!(f, "Parse error: {}", err), DekuError::InvalidParam(ref err) => write!(f, "Invalid param error: {}", err), DekuError::Unexpected(ref err) => write!(f, "Unexpected error: {}", err), + DekuError::Assertion(ref err) => write!(f, "Assertion error: {}", err), } } } diff --git a/tests/test_attributes/mod.rs b/tests/test_attributes/mod.rs index 22ba8275..25eca934 100644 --- a/tests/test_attributes/mod.rs +++ b/tests/test_attributes/mod.rs @@ -1,3 +1,5 @@ +mod test_assert; +mod test_assert_eq; mod test_cond; mod test_ctx; mod test_limits; diff --git a/tests/test_attributes/test_assert.rs b/tests/test_attributes/test_assert.rs new file mode 100644 index 00000000..5a306d74 --- /dev/null +++ b/tests/test_attributes/test_assert.rs @@ -0,0 +1,42 @@ +use deku::prelude::*; +use hexlit::hex; +use rstest::rstest; +use std::convert::{TryFrom, TryInto}; + +#[derive(Default, PartialEq, Debug, DekuRead, DekuWrite)] +struct TestStruct { + field_a: u8, + #[deku(assert = "*field_a + *field_b >= 3")] + field_b: u8, +} + +#[rstest(input, expected, + case(&hex!("0102"), TestStruct { + field_a: 0x01, + field_b: 0x02, + }), + + #[should_panic(expected = r#"Assertion("field \'field_b\' failed assertion: * field_a + * field_b >= 3")"#)] + case(&hex!("0101"), TestStruct::default()) +)] +fn test_assert_read(input: &[u8], expected: TestStruct) { + let ret_read = TestStruct::try_from(input).unwrap(); + assert_eq!(expected, ret_read); +} + +#[rstest(input, expected, + case(TestStruct { + field_a: 0x01, + field_b: 0x02, + }, hex!("0102").to_vec()), + + #[should_panic(expected = r#"Assertion("field \'field_b\' failed assertion: * field_a + * field_b >= 3")"#)] + case(TestStruct { + field_a: 0x01, + field_b: 0x01, + }, hex!("").to_vec()), +)] +fn test_assert_write(input: TestStruct, expected: Vec) { + let ret_write: Vec = input.try_into().unwrap(); + assert_eq!(expected, ret_write); +} diff --git a/tests/test_attributes/test_assert_eq.rs b/tests/test_attributes/test_assert_eq.rs new file mode 100644 index 00000000..2c356ae5 --- /dev/null +++ b/tests/test_attributes/test_assert_eq.rs @@ -0,0 +1,42 @@ +use deku::prelude::*; +use hexlit::hex; +use rstest::rstest; +use std::convert::{TryFrom, TryInto}; + +#[derive(Default, PartialEq, Debug, DekuRead, DekuWrite)] +struct TestStruct { + field_a: u8, + #[deku(assert_eq = "*field_a")] + field_b: u8, +} + +#[rstest(input, expected, + case(&hex!("0101"), TestStruct { + field_a: 0x01, + field_b: 0x01, + }), + + #[should_panic(expected = r#"Assertion("field \'field_b\' failed assertion: field_b == * field_a")"#)] + case(&hex!("0102"), TestStruct::default()) +)] +fn test_assert_eq_read(input: &[u8], expected: TestStruct) { + let ret_read = TestStruct::try_from(input).unwrap(); + assert_eq!(expected, ret_read); +} + +#[rstest(input, expected, + case(TestStruct { + field_a: 0x01, + field_b: 0x01, + }, hex!("0101").to_vec()), + + #[should_panic(expected = r#"Assertion("field \'field_b\' failed assertion: field_b == * field_a")"#)] + case(TestStruct { + field_a: 0x01, + field_b: 0x02, + }, hex!("").to_vec()), +)] +fn test_assert_eq_write(input: TestStruct, expected: Vec) { + let ret_write: Vec = input.try_into().unwrap(); + assert_eq!(expected, ret_write); +}