Skip to content

Commit

Permalink
Fix out-of-bound (OOB) input read in AES-XTS Decrypt in AVX-512 imple…
Browse files Browse the repository at this point in the history
…mentation (#2227)

- Fix instruction that caused out-of-bound read in the input reading of
the 16x loop (which processes a batch of 16 blocks of AES, 1 block = 16
bytes). This was triggered on lengths that are in the range 
[16*k * (16 bytes), (16*k +3)* (16 bytes)-1], k = 1, 2, ... 
The instruction was reading up to 3*16 bytes beyond the input length bound.

- The fix was inspired by the 8x loop in
/~https://github.com/aws/aws-lc/blob/becf5785c131012bb5a64f3da6cdb117ddc0f431/crypto/fipsmodule/aes/asm/aesni-xts-avx512.pl#L2544

- The existing unit tests cover those cases but there were no explicit
memory protections and ASAN doesn't instrument assembly code to check
for out-of-bound reads even when the subsequent memory is explicitly
poisoned.
 
### Call-outs:
N/A

### Testing:
On c6i, without the fix, the unit test segfaults
```
./crypto/crypto_test "--gtest_filter=XTSTest.*"
Note: Google Test filter = XTSTest.*
[==========] Running 4 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 4 tests from XTSTest
[ RUN      ] XTSTest.TestVectors
Segmentation fault (core dumped)
```
By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license and the ISC license.
  • Loading branch information
nebeid authored Feb 28, 2025
1 parent 50e6d59 commit eb0c0c0
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 22 deletions.
2 changes: 1 addition & 1 deletion crypto/fipsmodule/aes/asm/aesni-xts-avx512.pl
Original file line number Diff line number Diff line change
Expand Up @@ -2493,7 +2493,7 @@
vmovdqu8 0x40($input),%zmm2
vmovdqu8 0x80($input),%zmm3
vmovdqu8 0xc0($input),%zmm4
vmovdqu8 0xf0($input),%zmm5
vmovdqu8 0xf0($input),%xmm5
add \$0x100,$input
___
}
Expand Down
72 changes: 54 additions & 18 deletions crypto/fipsmodule/modes/xts_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
#include "internal.h"
#include "../../test/test_util.h"


#if defined(OPENSSL_LINUX)
#include <sys/mman.h>
#endif
struct XTSTestCase {
const char *key_hex;
const char *iv_hex;
Expand Down Expand Up @@ -995,8 +997,30 @@ static const XTSTestCase kXTSTestCases[] = {
},
};

#if defined(OPENSSL_LINUX)
static uint8_t *get_buffer_end(int pagesize) {
uint8_t *two_pages_p = (uint8_t *)mmap(NULL, 2*pagesize, PROT_READ|PROT_WRITE,
MAP_PRIVATE|MAP_ANONYMOUS, -1, 0);
EXPECT_TRUE(two_pages_p != NULL) << "mmap returned NULL.";

int ret = mprotect(two_pages_p + pagesize, pagesize, PROT_NONE);
EXPECT_TRUE(ret == 0) << "mprotect failed.";

return two_pages_p + pagesize;
}

static void free_memory(uint8_t *addr, int pagesize) {
munmap(addr - pagesize, 2 * pagesize);
}
#endif

TEST(XTSTest, TestVectors) {
unsigned test_num = 0;
#if defined(OPENSSL_LINUX)
int pagesize = sysconf(_SC_PAGE_SIZE);
uint8_t *in_buffer_end = get_buffer_end(pagesize);
uint8_t *out_buffer_end = get_buffer_end(pagesize);
#endif
for (const auto &test : kXTSTestCases) {
test_num++;
SCOPED_TRACE(test_num);
Expand All @@ -1013,45 +1037,57 @@ TEST(XTSTest, TestVectors) {
ASSERT_EQ(EVP_CIPHER_iv_length(cipher), iv.size());
ASSERT_EQ(plaintext.size(), ciphertext.size());

int len;
uint8_t *in_p, *out_p;
#if defined(OPENSSL_LINUX)
ASSERT_GE(pagesize, (int)plaintext.size());
in_p = in_buffer_end - plaintext.size();
out_p = out_buffer_end - plaintext.size();
OPENSSL_memset(in_p, 0x00, plaintext.size());
OPENSSL_memset(out_p, 0x00, plaintext.size());
#else
std::unique_ptr<uint8_t[]> in(new uint8_t[plaintext.size()]);
std::unique_ptr<uint8_t[]> out(new uint8_t[plaintext.size()]);
in_p = in.get();
out_p = out.get();
#endif

// Note XTS doesn't support streaming, so we only test single-shot inputs.
for (bool in_place : {false, true}) {
SCOPED_TRACE(in_place);

// Test encryption.
bssl::Span<const uint8_t> in = plaintext;
std::vector<uint8_t> out(plaintext.size());

OPENSSL_memcpy(in_p, plaintext.data(), plaintext.size());
if (in_place) {
out = plaintext;
in = out;
out_p = in_p;
}

bssl::ScopedEVP_CIPHER_CTX ctx;
ASSERT_TRUE(EVP_EncryptInit_ex(ctx.get(), cipher, nullptr, key.data(),
iv.data()));
int len;
ASSERT_TRUE(
EVP_EncryptUpdate(ctx.get(), out.data(), &len, in.data(), in.size()));
out.resize(len);
EXPECT_EQ(Bytes(ciphertext), Bytes(out));
EVP_EncryptUpdate(ctx.get(), out_p, &len, in_p, plaintext.size()));
EXPECT_EQ(Bytes(ciphertext), Bytes(out_p, static_cast<size_t>(len)));

// Test decryption.
in = ciphertext;
out.clear();
out.resize(plaintext.size());
if (in_place) {
out = ciphertext;
in = out;

if (!in_place) {
OPENSSL_memset(in_p, 0, len);
}

ctx.Reset();
ASSERT_TRUE(EVP_DecryptInit_ex(ctx.get(), cipher, nullptr, key.data(),
iv.data()));
ASSERT_TRUE(
EVP_DecryptUpdate(ctx.get(), out.data(), &len, in.data(), in.size()));
out.resize(len);
EXPECT_EQ(Bytes(plaintext), Bytes(out));
EVP_DecryptUpdate(ctx.get(), in_p, &len, out_p, ciphertext.size()));
EXPECT_EQ(Bytes(plaintext), Bytes(in_p, static_cast<size_t>(len)));
}
}
#if defined(OPENSSL_LINUX)
free_memory(in_buffer_end, pagesize);
free_memory(out_buffer_end, pagesize);
#endif
}

// Negative test for key1 = key2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3271,7 +3271,7 @@ aes_hw_xts_decrypt_avx512:
vmovdqu8 64(%rdi),%zmm2
vmovdqu8 128(%rdi),%zmm3
vmovdqu8 192(%rdi),%zmm4
vmovdqu8 240(%rdi),%zmm5
vmovdqu8 240(%rdi),%xmm5
addq $0x100,%rdi
vpxorq %zmm9,%zmm1,%zmm1
vpxorq %zmm10,%zmm2,%zmm2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3271,7 +3271,7 @@ L$_main_loop_run_16_amivrujEyduiFoi:
vmovdqu8 64(%rdi),%zmm2
vmovdqu8 128(%rdi),%zmm3
vmovdqu8 192(%rdi),%zmm4
vmovdqu8 240(%rdi),%zmm5
vmovdqu8 240(%rdi),%xmm5
addq $0x100,%rdi
vpxorq %zmm9,%zmm1,%zmm1
vpxorq %zmm10,%zmm2,%zmm2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3330,7 +3330,7 @@ $L$_main_loop_run_16_amivrujEyduiFoi:
vmovdqu8 zmm2,ZMMWORD[64+rcx]
vmovdqu8 zmm3,ZMMWORD[128+rcx]
vmovdqu8 zmm4,ZMMWORD[192+rcx]
vmovdqu8 zmm5,ZMMWORD[240+rcx]
vmovdqu8 xmm5,XMMWORD[240+rcx]
add rcx,0x100
vpxorq zmm1,zmm1,zmm9
vpxorq zmm2,zmm2,zmm10
Expand Down

0 comments on commit eb0c0c0

Please sign in to comment.