From 454efce83605988c967c5e97f6427b45220df520 Mon Sep 17 00:00:00 2001 From: Yuchen Wu Date: Thu, 30 Dec 2021 13:38:46 -0800 Subject: [PATCH] Fix SIMD header value check on char >= 0x80 The SIMD intrinsics *_cmpgt_epi8 are for signed chars. This change correctly performs the unsigned comparison. --- src/simd/avx2.rs | 32 ++++++++++++++++++++++++++------ src/simd/sse42.rs | 32 ++++++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/src/simd/avx2.rs b/src/simd/avx2.rs index 83aea98..89b76ba 100644 --- a/src/simd/avx2.rs +++ b/src/simd/avx2.rs @@ -100,10 +100,11 @@ unsafe fn match_header_value_char_32_avx(buf: &[u8]) -> usize { // %x09 %x20-%x7e %x80-%xff let TAB: __m256i = _mm256_set1_epi8(0x09); let DEL: __m256i = _mm256_set1_epi8(0x7f); - let LOW: __m256i = _mm256_set1_epi8(0x1f); + let LOW: __m256i = _mm256_set1_epi8(0x20); let dat = _mm256_lddqu_si256(ptr as *const _); - let low = _mm256_cmpgt_epi8(dat, LOW); + // unsigned comparison dat >= LOW + let low = _mm256_cmpeq_epi8(_mm256_max_epu8(dat, LOW), dat); let tab = _mm256_cmpeq_epi8(dat, TAB); let del = _mm256_cmpeq_epi8(dat, DEL); let bit = _mm256_andnot_si256(del, _mm256_or_si256(low, tab)); @@ -126,11 +127,30 @@ fn avx2_code_matches_uri_chars_table() { } unsafe { - assert!(byte_is_allowed(b'_')); + assert!(byte_is_allowed(b'_', parse_uri_batch_32)); for (b, allowed) in ::URI_MAP.iter().cloned().enumerate() { assert_eq!( - byte_is_allowed(b as u8), allowed, + byte_is_allowed(b as u8, parse_uri_batch_32), allowed, + "byte_is_allowed({:?}) should be {:?}", b, allowed, + ); + } + } +} + +#[test] +fn avx2_code_matches_header_value_chars_table() { + match super::detect() { + super::AVX_2 | super::AVX_2_AND_SSE_42 => {}, + _ => return, + } + + unsafe { + assert!(byte_is_allowed(b'_', match_header_value_batch_32)); + + for (b, allowed) in ::HEADER_VALUE_MAP.iter().cloned().enumerate() { + assert_eq!( + byte_is_allowed(b as u8, match_header_value_batch_32), allowed, "byte_is_allowed({:?}) should be {:?}", b, allowed, ); } @@ -138,7 +158,7 @@ fn avx2_code_matches_uri_chars_table() { } #[cfg(test)] -unsafe fn byte_is_allowed(byte: u8) -> bool { +unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>) -> Scan) -> bool { let slice = [ b'_', b'_', b'_', b'_', b'_', b'_', b'_', b'_', @@ -151,7 +171,7 @@ unsafe fn byte_is_allowed(byte: u8) -> bool { ]; let mut bytes = Bytes::new(&slice); - parse_uri_batch_32(&mut bytes); + f(&mut bytes); match bytes.pos() { 32 => true, diff --git a/src/simd/sse42.rs b/src/simd/sse42.rs index 49cecbd..d50eebe 100644 --- a/src/simd/sse42.rs +++ b/src/simd/sse42.rs @@ -85,10 +85,11 @@ unsafe fn match_header_value_char_16_sse(buf: &[u8]) -> usize { // %x09 %x20-%x7e %x80-%xff let TAB: __m128i = _mm_set1_epi8(0x09); let DEL: __m128i = _mm_set1_epi8(0x7f); - let LOW: __m128i = _mm_set1_epi8(0x1f); + let LOW: __m128i = _mm_set1_epi8(0x20); let dat = _mm_lddqu_si128(ptr as *const _); - let low = _mm_cmpgt_epi8(dat, LOW); + // unsigned comparison dat >= LOW + let low = _mm_cmpeq_epi8(_mm_max_epu8(dat, LOW), dat); let tab = _mm_cmpeq_epi8(dat, TAB); let del = _mm_cmpeq_epi8(dat, DEL); let bit = _mm_andnot_si128(del, _mm_or_si128(low, tab)); @@ -106,11 +107,30 @@ fn sse_code_matches_uri_chars_table() { } unsafe { - assert!(byte_is_allowed(b'_')); + assert!(byte_is_allowed(b'_', parse_uri_batch_16)); for (b, allowed) in ::URI_MAP.iter().cloned().enumerate() { assert_eq!( - byte_is_allowed(b as u8), allowed, + byte_is_allowed(b as u8, parse_uri_batch_16), allowed, + "byte_is_allowed({:?}) should be {:?}", b, allowed, + ); + } + } +} + +#[test] +fn sse_code_matches_header_value_chars_table() { + match super::detect() { + super::SSE_42 | super::AVX_2_AND_SSE_42 => {}, + _ => return, + } + + unsafe { + assert!(byte_is_allowed(b'_', match_header_value_batch_16)); + + for (b, allowed) in ::HEADER_VALUE_MAP.iter().cloned().enumerate() { + assert_eq!( + byte_is_allowed(b as u8, match_header_value_batch_16), allowed, "byte_is_allowed({:?}) should be {:?}", b, allowed, ); } @@ -118,7 +138,7 @@ fn sse_code_matches_uri_chars_table() { } #[cfg(test)] -unsafe fn byte_is_allowed(byte: u8) -> bool { +unsafe fn byte_is_allowed(byte: u8, f: unsafe fn(bytes: &mut Bytes<'_>)) -> bool { let slice = [ b'_', b'_', b'_', b'_', b'_', b'_', b'_', b'_', @@ -127,7 +147,7 @@ unsafe fn byte_is_allowed(byte: u8) -> bool { ]; let mut bytes = Bytes::new(&slice); - parse_uri_batch_16(&mut bytes); + f(&mut bytes); match bytes.pos() { 16 => true,