Skip to content

Commit

Permalink
add reentrancy guard
Browse files Browse the repository at this point in the history
  • Loading branch information
trmid committed Jun 27, 2024
1 parent 36aabd7 commit 0651d14
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/Claimer.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { UD2x18 } from "prb-math/UD2x18.sol";
import { UD60x18, convert } from "prb-math/UD60x18.sol";
import { PrizePool } from "pt-v5-prize-pool/PrizePool.sol";
import { SafeCast } from "openzeppelin/utils/math/SafeCast.sol";
import { ReentrancyGuard } from "openzeppelin/security/ReentrancyGuard.sol";

import { LinearVRGDALib } from "./libraries/LinearVRGDALib.sol";
import { IClaimable } from "pt-v5-claimable-interface/interfaces/IClaimable.sol";
Expand All @@ -32,7 +33,7 @@ error TimeToReachMaxFeeZero();
/// @title Variable Rate Gradual Dutch Auction (VRGDA) Claimer
/// @author G9 Software Inc.
/// @notice This contract uses a variable rate gradual dutch auction to incentivize prize claims on behalf of others. Fees for each canary tier is set to the respective tier's prize size.
contract Claimer {
contract Claimer is ReentrancyGuard {

/// @notice Emitted when a claim reverts
/// @param vault The vault for which the claim failed
Expand Down Expand Up @@ -94,7 +95,7 @@ contract Claimer {
uint32[][] calldata _prizeIndices,
address _feeRecipient,
uint256 _minFeePerClaim
) external returns (uint256 totalFees) {
) external nonReentrant returns (uint256 totalFees) {
bool feeRecipientZeroAddress = address(0) == _feeRecipient;
if (feeRecipientZeroAddress && _minFeePerClaim != 0) {
revert FeeRecipientZeroAddress();
Expand Down
27 changes: 27 additions & 0 deletions test/Claimer.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { SD59x18 } from "prb-math/SD59x18.sol";
import { PrizePool, AlreadyClaimed } from "pt-v5-prize-pool/PrizePool.sol";
import { IClaimable } from "pt-v5-claimable-interface/interfaces/IClaimable.sol";
import { LinearVRGDALib } from "../src/libraries/LinearVRGDALib.sol";
import { ReentrancyMock } from "./mock/ReentrancyMock.sol";

// Custom Errors
error ClaimArraySizeMismatch(uint256 winnersLength, uint256 prizeIndicesLength);
Expand Down Expand Up @@ -126,6 +127,32 @@ contract ClaimerTest is Test {
assertEq(totalFees, NO_SALES_100_SECONDS_BEHIND_SCHEDULE_FEE, "Total fees");
}

function testClaimPrizes_reentrancyGuard() public {
address[] memory winners = newWinners(winner1, winner2);
uint32[][] memory prizeIndices = newPrizeIndices(2, 1);

address[] memory reentrancyWinners = newWinners(winner3);
uint32[][] memory reentrancyPrizeIndices = newPrizeIndices(1, 1);

ReentrancyMock reentrancyVault = new ReentrancyMock(address(claimer));
reentrancyVault.setReentrancyClaimInfo(
winner1, // only triggerred by winner1
IClaimable(address(reentrancyVault)),
1,
reentrancyWinners,
reentrancyPrizeIndices,
address(this),
0
);

mockPrizePool(1, -100, 0);

vm.expectEmit();
emit ClaimError(IClaimable(address(reentrancyVault)), 1, winner1, 0, abi.encodeWithSignature("Error(string)", "ReentrancyGuard: reentrant call"));
uint256 totalFees = claimer.claimPrizes(reentrancyVault, 1, winners, prizeIndices, address(this), 0);
assertLt(totalFees, NO_SALES_100_SECONDS_BEHIND_SCHEDULE_FEE, "Total fees"); // 2 fee-split expected, but one fails so the received fees are slightly less
}

function testClaimPrizes_singleNoFeeSavesGas() public {
// With fee
address[] memory winners = newWinners(winner1);
Expand Down
56 changes: 56 additions & 0 deletions test/mock/ReentrancyMock.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

// SPDX-License-Identifier: MIT
pragma solidity ^0.8.24;

import { IClaimable } from "pt-v5-claimable-interface/interfaces/IClaimable.sol";
import { Claimer } from "../../src/Claimer.sol";

contract ReentrancyMock is IClaimable {

address public claimer;
address public badGuy;
bytes public reentrancyCalldata;

constructor(address claimer_) {
claimer = claimer_;
}

function claimPrize(
address _winner,
uint8 _tier,
uint32 _prizeIndex,
uint96 _reward,
address _rewardRecipient
) external returns (uint256) {
if (_winner == badGuy) {
(bool success, bytes memory data) = claimer.call(reentrancyCalldata);
require(success == false, "reentrancy succeeded...");
assembly {
revert(add(32, data), mload(data))
}
}
return 1;
}

function setReentrancyClaimInfo(
address _badGuy,
IClaimable _vault,
uint8 _tier,
address[] calldata _winners,
uint32[][] calldata _prizeIndices,
address _feeRecipient,
uint256 _minFeePerClaim
) external returns (uint256) {
badGuy = _badGuy;
reentrancyCalldata = abi.encodeWithSelector(
Claimer.claimPrizes.selector,
_vault,
_tier,
_winners,
_prizeIndices,
_feeRecipient,
_minFeePerClaim
);
}

}

0 comments on commit 0651d14

Please sign in to comment.