diff --git a/src/constants/BitMask.sol b/src/constants/BitMask.sol new file mode 100644 index 0000000..c749f9f --- /dev/null +++ b/src/constants/BitMask.sol @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +uint256 constant MASK_1_BIT = 0x1; +uint256 constant MASK_8_BITS = 0xff; +uint256 constant MASK_24_BITS = 0xffffff; +uint256 constant MASK_127_BITS = 0x7fffffffffffffffffffffffffffffff; +uint256 constant MASK_128_BITS = 0xffffffffffffffffffffffffffffffff; +uint160 constant MASK_160_BITS = 0x00ffffffffffffffffffffffffffffffffffffffff; + +uint256 constant MASK_BYTES_4 = 0xffffffff00000000000000000000000000000000000000000000000000000000; diff --git a/src/hooks/swap/KSConditionalSwapHook.sol b/src/hooks/swap/KSConditionalSwapHook.sol index add6024..bf93594 100644 --- a/src/hooks/swap/KSConditionalSwapHook.sol +++ b/src/hooks/swap/KSConditionalSwapHook.sol @@ -2,33 +2,40 @@ pragma solidity ^0.8.0; import {IKSSmartIntentHook} from '../../interfaces/hooks/IKSSmartIntentHook.sol'; +import {PackedU128} from '../../types/PackedU128.sol'; import {BaseStatefulHook} from '../base/BaseStatefulHook.sol'; +import {CalldataDecoder} from 'ks-common-sc/src/libraries/calldata/CalldataDecoder.sol'; import {TokenHelper} from 'ks-common-sc/src/libraries/token/TokenHelper.sol'; +import {MerkleProof} from 'openzeppelin-contracts/contracts/utils/cryptography/MerkleProof.sol'; import {ActionData} from '../../types/ActionData.sol'; import {IntentData} from '../../types/IntentData.sol'; contract KSConditionalSwapHook is BaseStatefulHook { using TokenHelper for address; + using CalldataDecoder for bytes; - error InvalidTokenIn(address tokenIn, address actualTokenIn); error AmountInMismatch(uint256 amountIn, uint256 actualAmountIn); - error InvalidSwap(); + error InvalidProof(); + error InvalidTime(uint256 timestamp, uint256 min, uint256 maxT); + error InvalidAmountIn(uint256 amountIn, uint256 minAmountIn, uint256 maxAmountIn); + error InvalidFees( + uint256 srcFeePercent, uint256 dstFeePercent, uint256 maxSrcFee, uint256 maxDstFee + ); + error MaxLeafIndex(); + error InvalidPrice(uint256 price, uint256 minPrice, uint256 maxPrice); + error InvalidSwapLimit(uint256 swapCount, uint256 limit); uint256 public constant DENOMINATOR = 1e18; uint256 public constant PRECISION = 1_000_000; /** * @notice Data structure for conditional swap - * @param swapConditions The swap conditions, a swap will be executed if one of the conditions is met - * @param srcTokens The source tokens - * @param dstTokens The destination tokens + * @param root The Merkle root of all valid swap conditions * @param recipient The recipient of the destination token */ struct SwapHookData { - SwapCondition[][] swapConditions; - address[] srcTokens; - address[] dstTokens; + bytes32 root; address recipient; } @@ -42,16 +49,16 @@ contract KSConditionalSwapHook is BaseStatefulHook { */ struct SwapCondition { uint8 swapLimit; - uint256 timeLimits; - uint256 amountInLimits; - uint256 maxFees; - uint256 priceLimits; + PackedU128 timeLimits; + PackedU128 amountInLimits; + PackedU128 maxFees; + PackedU128 priceLimits; } struct SwapValidationData { - SwapCondition[] swapConditions; + SwapCondition swapCondition; bytes32 intentHash; - uint256 intentIndex; + uint256 leafIndex; address tokenIn; address tokenOut; uint256 amountIn; @@ -64,14 +71,12 @@ contract KSConditionalSwapHook is BaseStatefulHook { /** * @notice Tracks swap execution counts for each condition to enforce swap limits - * @dev Maps intentHash -> intentIndex -> packedIndexes -> packedCounts + * @dev Maps intentHash -> packedIndexes -> packedCounts * Each uint256 stores up to 32 uint8 swap counts (8 bits each), indexed by swapIndexes / 32 * Individual counts are extracted using bit shifts based on swapIndexes % 32 */ - mapping( - bytes32 intentHash - => mapping(uint256 intentIndex => mapping(uint256 swapIndexes => uint256 swapCount)) - ) public swapRecord; + mapping(bytes32 intentHash => mapping(uint256 swapIndexes => uint256 swapCount)) public + swapRecord; constructor(address[] memory initialRouters) BaseStatefulHook(initialRouters) {} @@ -81,7 +86,6 @@ contract KSConditionalSwapHook is BaseStatefulHook { _; } - /// @inheritdoc IKSSmartIntentHook function beforeExecution( bytes32 intentHash, IntentData calldata intentData, @@ -94,25 +98,28 @@ contract KSConditionalSwapHook is BaseStatefulHook { returns (uint256[] memory fees, bytes memory beforeExecutionData) { SwapHookData calldata swapHookData = _decodeHookData(intentData.coreData.hookIntentData); - (uint256 index, uint256 intentSrcFee, uint256 intentDstFee) = - _decodeAndValidateHookActionData(actionData.hookActionData, swapHookData); + ( + bytes32[] calldata proof, + SwapCondition calldata condition, + uint256 leafIndex, + address tokenOut, + uint256 intentSrcFee, + uint256 intentDstFee + ) = _decodeHookActionData(actionData.hookActionData); address tokenIn = intentData.tokenData.erc20Data[actionData.erc20Ids[0]].token; - address tokenOut = swapHookData.dstTokens[index]; - uint256 amountIn = actionData.erc20Amounts[0]; - require( - tokenIn == swapHookData.srcTokens[index], - InvalidTokenIn(tokenIn, swapHookData.srcTokens[index]) - ); + bytes32 leaf = keccak256(abi.encode(leafIndex, tokenIn, tokenOut, condition)); + require(MerkleProof.verifyCalldata(proof, swapHookData.root, leaf), InvalidProof()); + uint256 amountIn = actionData.erc20Amounts[0]; fees = new uint256[](1); fees[0] = (amountIn * intentSrcFee) / PRECISION; beforeExecutionData = abi.encode( SwapValidationData({ - swapConditions: swapHookData.swapConditions[index], + swapCondition: condition, intentHash: intentHash, - intentIndex: index, + leafIndex: leafIndex, tokenIn: tokenIn, tokenOut: tokenOut, amountIn: amountIn, @@ -129,7 +136,6 @@ contract KSConditionalSwapHook is BaseStatefulHook { return (fees, beforeExecutionData); } - /// @inheritdoc IKSSmartIntentHook function afterExecution( bytes32, IntentData calldata intentData, @@ -159,11 +165,13 @@ contract KSConditionalSwapHook is BaseStatefulHook { tokenOut, validationData.recipient, validationData.dstFeePercent ) - validationData.recipientBalanceBefore; - uint256 price = (amountOut * DENOMINATOR) / amountIn; + uint256 fee = (amountOut * validationData.dstFeePercent) / PRECISION; + uint256 price = ((amountOut - fee) * DENOMINATOR) / amountIn; _validateSwapCondition( - validationData.swapConditions, - swapRecord[validationData.intentHash][validationData.intentIndex], + validationData.swapCondition, + validationData.leafIndex, + swapRecord[validationData.intentHash], price, amountIn, validationData.srcFeePercent, @@ -178,10 +186,10 @@ contract KSConditionalSwapHook is BaseStatefulHook { tokens[0] = tokenOut; fees = new uint256[](1); - fees[0] = (amountOut * validationData.dstFeePercent) / PRECISION; + fees[0] = fee; amounts = new uint256[](1); - amounts[0] = amountOut - fees[0]; + amounts[0] = amountOut - fee; recipient = validationData.recipient; @@ -191,89 +199,88 @@ contract KSConditionalSwapHook is BaseStatefulHook { /** * @notice Gets the number of times a specific swap condition has been executed * @param intentHash The hash of the intent - * @param intentIndex The index of the specific intent * @param conditionIndex The index of the swap condition to check * @return The number of times this condition has been executed */ - function getSwapExecutionCount(bytes32 intentHash, uint256 intentIndex, uint256 conditionIndex) + function getSwapExecutionCount(bytes32 intentHash, uint256 conditionIndex) public view returns (uint256) { - uint256 packedValue = swapRecord[intentHash][intentIndex][conditionIndex / 32]; + uint256 packedValue = swapRecord[intentHash][conditionIndex / 32]; uint256 bytePosition = conditionIndex % 32; return uint8(packedValue >> (bytePosition * 8)); } function _validateSwapCondition( - SwapCondition[] calldata swapCondition, + SwapCondition calldata condition, + uint256 index, mapping(uint256 swapIndexes => uint256 swapCounts) storage record, uint256 price, uint256 amountIn, uint256 srcFeePercent, uint256 dstFeePercent ) internal { - for (uint256 i; i < swapCondition.length; ++i) { - SwapCondition calldata condition = swapCondition[i]; - - if ( - block.timestamp < condition.timeLimits >> 128 - || block.timestamp > uint128(condition.timeLimits) - ) { - continue; - } - - if ( - amountIn < condition.amountInLimits >> 128 || amountIn > uint128(condition.amountInLimits) - ) { - continue; - } - - if (srcFeePercent > condition.maxFees >> 128 || dstFeePercent > uint128(condition.maxFees)) { - continue; - } - - if (price < condition.priceLimits >> 128 || price > uint128(condition.priceLimits)) { - continue; - } - - if (!_increaseByOne(record, uint8(i), condition.swapLimit)) { - continue; - } - return; + if ( + block.timestamp < condition.timeLimits.value0() + || block.timestamp > condition.timeLimits.value1() + ) { + revert InvalidTime( + block.timestamp, condition.timeLimits.value0(), condition.timeLimits.value1() + ); + } + + if ( + amountIn < condition.amountInLimits.value0() || amountIn > condition.amountInLimits.value1() + ) { + revert InvalidAmountIn( + amountIn, condition.amountInLimits.value0(), condition.amountInLimits.value1() + ); + } + + if (srcFeePercent > condition.maxFees.value0() || dstFeePercent > condition.maxFees.value1()) { + revert InvalidFees( + srcFeePercent, dstFeePercent, condition.maxFees.value0(), condition.maxFees.value1() + ); } - revert InvalidSwap(); + if (price < condition.priceLimits.value0() || price > condition.priceLimits.value1()) { + revert InvalidPrice(price, condition.priceLimits.value0(), condition.priceLimits.value1()); + } + + (bool success, uint8 swapCount) = _increaseByOne(record, index, condition.swapLimit); + if (!success) { + revert InvalidSwapLimit(swapCount, condition.swapLimit); + } } /** * @notice Increments swap count for a specific condition index - * @dev Uses bit manipulation to efficiently store counts in packed format + * @dev Uses bit manipulation to efficiently store counts in packed format. + * Each uint256 slot holds 32 uint8 counters; the slot is selected by index/32 + * and the byte position within that slot by index%32. * @param record Storage mapping containing packed swap counts * @param index The condition index to increment * @param limit Maximum allowed swaps for this condition - * @return success True if increment was successful (within limit), false otherwise + * @return success True if increment was within limit and the count was stored, false otherwise + * @return swapCount The new swap count after incrementing (or the over-limit value if failed) */ function _increaseByOne( mapping(uint256 packedIndexes => uint256 packedValues) storage record, - uint8 index, + uint256 index, uint8 limit - ) internal returns (bool) { - uint256 packedValue = record[index / 32]; - uint256 bytePosition = index % 32; - - uint8 swapCount = uint8(packedValue >> (bytePosition * 8)) + 1; - - if (swapCount > limit) { - return false; - } - - packedValue += 1 << (bytePosition * 8); + ) internal returns (bool, uint8) { + require(index <= type(uint8).max, MaxLeafIndex()); + uint256 slotKey = index / 32; + uint256 shift = (index % 32) * 8; + uint256 packedValue = record[slotKey]; + uint8 swapCount = uint8(packedValue >> shift) + 1; - record[index / 32] = packedValue; + if (swapCount > limit) return (false, swapCount); - return true; + record[slotKey] = packedValue + (1 << shift); + return (true, swapCount); } function _getRecipientBalance(address tokenOut, address recipient, uint256 feePercent) @@ -287,17 +294,6 @@ contract KSConditionalSwapHook is BaseStatefulHook { return tokenOut.balanceOf(recipient); } - // @dev: equivalent to abi.decode(data, (SwapCondition)) - function _decodeSwapCondition(bytes calldata data) - internal - pure - returns (SwapCondition calldata swapCondition) - { - assembly ('memory-safe') { - swapCondition := data.offset - } - } - // @dev: equivalent to abi.decode(data, (SwapHookData)) function _decodeHookData(bytes calldata data) internal @@ -305,24 +301,39 @@ contract KSConditionalSwapHook is BaseStatefulHook { returns (SwapHookData calldata hookData) { assembly ('memory-safe') { - hookData := add(data.offset, calldataload(data.offset)) + hookData := data.offset } } // @dev: equivalent to abi.decode(data, (uint256, uint256, uint256, uint256)) - function _decodeAndValidateHookActionData(bytes calldata data, SwapHookData calldata swapHookData) + function _decodeHookActionData(bytes calldata data) internal - view - returns (uint256 index, uint256 intentSrcFee, uint256 intentDstFee) + pure + returns ( + bytes32[] calldata proof, + SwapCondition calldata condition, + uint256 leafIndex, + address tokenOut, + uint256 intentSrcFee, + uint256 intentDstFee + ) { - uint256 packedFees; + PackedU128 packedFees; + assembly ('memory-safe') { + leafIndex := calldataload(add(data.offset, 0x20)) + tokenOut := calldataload(add(data.offset, 0x40)) + packedFees := calldataload(add(data.offset, 0x60)) + condition := add(data.offset, 0x80) + } + + (uint256 length, uint256 offset) = data.decodeLengthOffset(0); assembly ('memory-safe') { - index := calldataload(data.offset) - packedFees := calldataload(add(data.offset, 0x20)) + proof.length := length + proof.offset := offset } - intentSrcFee = packedFees >> 128; - intentDstFee = uint128(packedFees); + intentSrcFee = packedFees.value0(); + intentDstFee = packedFees.value1(); } // @dev: equivalent to abi.decode(data, (SwapValidationData)) @@ -332,7 +343,7 @@ contract KSConditionalSwapHook is BaseStatefulHook { returns (SwapValidationData calldata validationData) { assembly ('memory-safe') { - validationData := add(data.offset, calldataload(data.offset)) + validationData := data.offset } } } diff --git a/src/types/PackedU128.sol b/src/types/PackedU128.sol new file mode 100644 index 0000000..385002c --- /dev/null +++ b/src/types/PackedU128.sol @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import 'src/constants/BitMask.sol'; + +/** + * @notice two 128-bit values packed into a 256-bit value + * where the first 128 bits are the first value + * and the last 128 bits are the second value. + */ +type PackedU128 is uint256; + +using PackedU128Library for PackedU128 global; + +/** + * @notice pack two 128-bit values into a 256-bit value + * @dev use 256-bit params for versatility + */ +function toPackedU128(uint256 value0, uint256 value1) pure returns (PackedU128 packedU128) { + assembly ('memory-safe') { + packedU128 := or(shl(128, value0), value1) + } +} + +library PackedU128Library { + /// @notice get the first 128 bits of the packed value + function value0(PackedU128 packedU128) internal pure returns (uint128 _value0) { + assembly ('memory-safe') { + _value0 := shr(128, packedU128) + } + } + + /// @notice get the last 128 bits of the packed value + function value1(PackedU128 packedU128) internal pure returns (uint128 _value1) { + assembly ('memory-safe') { + _value1 := and(packedU128, MASK_128_BITS) + } + } + + /// @notice unpack the packed value into two 128-bit values + function unpack(PackedU128 packedU128) internal pure returns (uint128 _value0, uint128 _value1) { + assembly ('memory-safe') { + _value0 := shr(128, packedU128) + _value1 := and(packedU128, MASK_128_BITS) + } + } +} diff --git a/test/ConditionalSwap.t.sol b/test/ConditionalSwap.t.sol index ff31371..426c547 100644 --- a/test/ConditionalSwap.t.sol +++ b/test/ConditionalSwap.t.sol @@ -4,11 +4,14 @@ pragma solidity ^0.8.0; import './Base.t.sol'; import 'src/hooks/swap/KSConditionalSwapHook.sol'; +import 'src/types/PackedU128.sol'; +import 'test/utils/MerkleUtils.sol'; contract ConditionalSwapTest is BaseTest { using SafeERC20 for IERC20; using TokenHelper for address; using ArraysHelper for *; + using MerkleUtils for *; bytes swapdata = hex'00000000000000000000000000000000000000000000000000000000000000200000000000000000000000000f4a1d7fdf4890be35e71f3e0bbc4a0ec377eca3000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000a000000000000000000000000000000000000000000000000000000000000007a000000000000000000000000000000000000000000000000000000000000009e000000000000000000000000000000000000000000000000000000000000006e0000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000c0000000000000000000000000dac17f958d2ee523a2206206994597c13d831ec70000000000000000000000002260fac5e5542a773aa44fbcfedf7c193bc2c5990000000000000000000000002e234DAe75C793f67A35089C9d99245E1C58470b0000000000000000000000000000000000000000000000000000000067db987b00000000000000000000000000000000000000000000000000000000000006800000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000060000000000000000000000000000000000000000000000000000000000000022000000000000000000000000000000000000000000000000000000000000004000000000000000000000000000000000000000000000000000000000000000040f59b1df7000000000000000000000000000000000000000000000000000000030000000000000000000000000000000000000000000000000000000000000160000000000000000000000000000000000000000000000000000000000000002000000000000000000000000066a9893cc07d91d95644aedd05d03f95e1dba8af000000000000000000000000000000000000000000000000000000003b9aca00000000000000000000000000000000000022d473030f116ddee9f6b43ac78ba3000000000000000000000000dac17f958d2ee523a2206206994597c13d831ec7000000000000000000000000a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48000000000000000000000000000000000000000000000000000000000000002300000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000012000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000040a9d4c672000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000180000000000000000000000000655edce464cc797526600a462a8154650eee4b77000000000000000000000000000000000000000000000000000000003b9d5f1a000000000000000000000000000000000000000000000000000000003b9d5f1a00000000000000000000000000000000000000000000000006dac07944b594800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48000000000000000000000000c02aaa39b223fe8d0a0e5c4f27ead9083c756cc20000000000000000000000000f4a1d7fdf4890be35e71f3e0bbc4a0ec377eca3000000000000005fa94793ea0000001a371930340fc8fbcc09c409c467db9414000000000000000000000000000000000000000000000000000000000000001bdcffd1bf68c2c17dcf00a25c935efba96aa63b7f75dd43d42b3df2cf7273c2260fb4b38a9db829fbfdabcc6262ac3982f1d31366bfde12a7b67f6f31ba52b2cb0000000000000000000000000000000000000000000000000000000000000040d90ce4910000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000000000000000001000000000000000000000000007f86bf177dd4f3494b841a37e810a34dd56c829b000000000000000000000000c02aaa39b223fe8d0a0e5c4f27ead9083c756cc20000000000000000000000002260fac5e5542a773aa44fbcfedf7c193bc2c5990000000000000000000000000000000000000000000000000000000000000002000000000000000000000000000000000000000000000000000000000000000100000000000000000000000000000000000000000000000006da929a6bb58cc0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000010000000000000000000000000011cbb0000000000000000000000000dac17f958d2ee523a2206206994597c13d831ec70000000000000000000000002260fac5e5542a773aa44fbcfedf7c193bc2c599000000000000000000000000000000000000000000000000000000000000016000000000000000000000000000000000000000000000000000000000000001a000000000000000000000000000000000000000000000000000000000000001e000000000000000000000000000000000000000000000000000000000000002000000000000000000000000002e234DAe75C793f67A35089C9d99245E1C58470b000000000000000000000000000000000000000000000000000000003b9aca00000000000000000000000000000000000000000000000000000000000011c7210000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000022000000000000000000000000000000000000000000000000000000000000000010000000000000000000000000f4a1d7fdf4890be35e71f3e0bbc4a0ec377eca30000000000000000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000000000000000003b9aca00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000024f7b22536f75726365223a22222c22416d6f756e74496e555344223a22313030302e31373135393231313738353037222c22416d6f756e744f7574555344223a22313030302e34373538333032323939323331222c22526566657272616c223a22222c22466c616773223a302c22416d6f756e744f7574223a2231313636323536222c2254696d657374616d70223a313734323434333436392c22526f7574654944223a2263383438663432632d326465322d343364382d623366372d636637366362666430363536222c22496e74656772697479496e666f223a7b224b65794944223a2231222c225369676e6174757265223a224e39426b4975436430714961362f4d64736635717a61657863436c3754413539426e4d70454741437a74432b5875325176494a36444c34476b7075746b636f627554395657357a42744e427a5463736b4e7768434662372f6f52675173676970424e693878716d323869524b3048496834527a70316457512f437737676a58375168653270313853506966492b7550674e5a34647a5a6a4461686b664d416852796d7765783233714942536a65565a6f44483932596a534b4e546176396f2f2f634754766476336a52555538536841763153464b55514b54515470682f4d4f71534f7370646c37306632714155705274566d7739434b4d383347726164506b55546f5854684a2f6c734e784561634267395a37617a363837394d366d31517538465a687237796374367a4242524a774171464e6646436a364b523969307a4e702f665a2b6876394b6970455341666d5078634e4d67773d3d227d7d0000000000000000000000000000000000'; @@ -22,8 +25,12 @@ contract ConditionalSwapTest is BaseTest { uint256 swapAmount = 1_000_000_000; + bytes32 root; + bytes32[] leaves; + KSConditionalSwapHook conditionalSwapHook; - uint256 currentPrice = 11_662_550_000_000; // USDC/BTC denominated by 1e18 + uint256 currentPrice = 1_166_255_000_000_000; // USDC/BTC denominated by 1e18 + KSConditionalSwapHook.SwapCondition defaultCondition; function setUp() public override { super.setUp(); @@ -63,6 +70,8 @@ contract ConditionalSwapTest is BaseTest { uint256 beforeSwapFee = (amountIn * feeBefore) / 1_000_000; uint256 afterSwapFee = (params.returnAmount * feeAfter) / 1_000_000; + bytes32[] memory _memLeaves = leaves; + ActionData memory actionData = _getActionData( intentData.tokenData, abi.encode( @@ -73,7 +82,10 @@ contract ConditionalSwapTest is BaseTest { feeAfter == 0 ? mainAddress : address(router), mainAddress ), - true + true, + 0, + MerkleUtils.getProof(_memLeaves, 0), + defaultCondition ); params.returnAmount = params.returnAmount - afterSwapFee; @@ -84,7 +96,7 @@ contract ConditionalSwapTest is BaseTest { if (feeBefore > maxSrcFee || feeAfter > maxDstFee) { vm.expectRevert( abi.encodeWithSelector( - KSConditionalSwapHook.InvalidSwap.selector, feeBefore, feeAfter, maxSrcFee, maxDstFee + KSConditionalSwapHook.InvalidFees.selector, feeBefore, feeAfter, maxSrcFee, maxDstFee ) ); vm.startPrank(caller); @@ -117,8 +129,15 @@ contract ConditionalSwapTest is BaseTest { _setUpMainAddress(intentData, false); + bytes32[] memory _memLeaves = leaves; + ActionData memory actionData = _getActionData( - intentData.tokenData, _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), false + intentData.tokenData, + _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), + false, + 0, + MerkleUtils.getProof(_memLeaves, 0), + defaultCondition ); vm.warp(vm.getBlockTimestamp() + 100); @@ -138,24 +157,24 @@ contract ConditionalSwapTest is BaseTest { { condition[0] = KSConditionalSwapHook.SwapCondition({ swapLimit: 1, - timeLimits: ((vm.getBlockTimestamp() - 100) << 128) | (vm.getBlockTimestamp() + 100), - amountInLimits: (swapAmount << 128) | swapAmount, - maxFees: (0 << 128) | type(uint128).max, - priceLimits: (0 << 128) | type(uint128).max + timeLimits: toPackedU128(block.timestamp - 100, block.timestamp + 100), + amountInLimits: toPackedU128(swapAmount, swapAmount), + maxFees: toPackedU128(0, type(uint128).max), + priceLimits: toPackedU128(0, type(uint128).max) }); condition[1] = KSConditionalSwapHook.SwapCondition({ swapLimit: 1, - timeLimits: ((vm.getBlockTimestamp() + 500) << 128) | (vm.getBlockTimestamp() + 700), - amountInLimits: (swapAmount << 128) | swapAmount, - maxFees: (0 << 128) | type(uint128).max, - priceLimits: (0 << 128) | type(uint128).max + timeLimits: toPackedU128(block.timestamp + 500, block.timestamp + 700), + amountInLimits: toPackedU128(swapAmount, swapAmount), + maxFees: toPackedU128(0, type(uint128).max), + priceLimits: toPackedU128(0, type(uint128).max) }); condition[2] = KSConditionalSwapHook.SwapCondition({ swapLimit: 1, - timeLimits: ((vm.getBlockTimestamp() + 1000) << 128) | (vm.getBlockTimestamp() + 1200), - amountInLimits: (swapAmount << 128) | swapAmount, - maxFees: (0 << 128) | type(uint128).max, - priceLimits: (0 << 128) | type(uint128).max + timeLimits: toPackedU128(block.timestamp + 1000, block.timestamp + 1200), + amountInLimits: toPackedU128(swapAmount, swapAmount), + maxFees: toPackedU128(0, type(uint128).max), + priceLimits: toPackedU128(0, type(uint128).max) }); } @@ -169,8 +188,9 @@ contract ConditionalSwapTest is BaseTest { } ActionData memory actionData; + TokenData memory tokenData; + bytes32[] memory _memLeaves = leaves; { - TokenData memory tokenData; tokenData.erc20Data = new ERC20Data[](1); tokenData.erc20Data[0] = ERC20Data({token: tokenIn, amount: swapAmount, permitData: ''}); actionData = _getActionData( @@ -183,7 +203,10 @@ contract ConditionalSwapTest is BaseTest { feeAfter == 0 ? mainAddress : address(router), mainAddress ), - true + true, + 0, + MerkleUtils.getProof(_memLeaves, 0), + condition[0] ); } @@ -192,6 +215,35 @@ contract ConditionalSwapTest is BaseTest { _swap(mode, intentData, actionData, 0, 0); } + // { + // actionData = _getActionData( + // tokenData, + // abi.encode( + // tokenIn, + // tokenOut, + // swapAmount, + // 1000, + // feeAfter == 0 ? mainAddress : address(router), + // mainAddress + // ), + // true, + // 1, + // MerkleUtils.getProof(_memLeaves, 1), + // condition[1] + // ); + // actionData.nonce += 1; + // } + { + actionData.hookActionData = abi.encode( + MerkleUtils.getProof(_memLeaves, 1), + 1, + tokenOut, + toPackedU128(feeBefore, feeAfter), + condition[1] + ); + actionData.nonce += 1; + } + // swap 2 { vm.warp(vm.getBlockTimestamp() + 500); @@ -202,7 +254,16 @@ contract ConditionalSwapTest is BaseTest { // swap 3 { vm.warp(vm.getBlockTimestamp() + 600); - actionData.nonce += 1; + actionData.hookActionData = abi.encode( + MerkleUtils.getProof(_memLeaves, 2), + 2, + tokenOut, + toPackedU128(feeBefore, feeAfter), + condition[2] + ); + actionData.nonce += 2; + } + { _swap(mode, intentData, actionData, 0, 2); } } @@ -215,10 +276,10 @@ contract ConditionalSwapTest is BaseTest { { condition[0] = KSConditionalSwapHook.SwapCondition({ swapLimit: 4, - timeLimits: (0 << 128) | type(uint128).max, - amountInLimits: (swapAmount << 128) | swapAmount, - maxFees: (0 << 128) | type(uint128).max, - priceLimits: ((1_000_000_000_000 - 100) << 128) | (1_000_000_000_000 + 100) + timeLimits: toPackedU128(0, type(uint128).max), + amountInLimits: toPackedU128(swapAmount, swapAmount), + maxFees: toPackedU128(0, type(uint128).max), + priceLimits: toPackedU128(1_000_000_000_000 - 100, 1_000_000_000_000 + 100) }); } @@ -231,8 +292,9 @@ contract ConditionalSwapTest is BaseTest { swapAmount = tmpSwapAmount; } ActionData memory actionData; + TokenData memory tokenData; + bytes32[] memory _memLeaves = leaves; { - TokenData memory tokenData; tokenData.erc20Data = new ERC20Data[](1); tokenData.erc20Data[0] = ERC20Data({token: tokenIn, amount: swapAmount, permitData: ''}); actionData = _getActionData( @@ -245,7 +307,10 @@ contract ConditionalSwapTest is BaseTest { feeAfter == 0 ? mainAddress : address(router), mainAddress ), - true + true, + 0, + MerkleUtils.getProof(_memLeaves, 0), + condition[0] ); } @@ -280,11 +345,11 @@ contract ConditionalSwapTest is BaseTest { uint256 balanceBefore = tokenOut.balanceOf(mainAddress); - assertEq(conditionalSwapHook.getSwapExecutionCount(hash, 0, index), swapCount); + assertEq(conditionalSwapHook.getSwapExecutionCount(hash, index), swapCount); vm.startPrank(caller); router.execute(intentData, dkSignature, guardian, gdSignature, actionData); vm.stopPrank(); - assertEq(conditionalSwapHook.getSwapExecutionCount(hash, 0, index), swapCount + 1); + assertEq(conditionalSwapHook.getSwapExecutionCount(hash, index), swapCount + 1); assertGt(tokenOut.balanceOf(mainAddress), balanceBefore); } @@ -296,24 +361,37 @@ contract ConditionalSwapTest is BaseTest { condition[0] = KSConditionalSwapHook.SwapCondition({ swapLimit: 1, - timeLimits: ((vm.getBlockTimestamp() + 100) << 128) | (vm.getBlockTimestamp() + 1000), - amountInLimits: (0 << 128) | type(uint128).max, - maxFees: (0 << 128) | type(uint128).max, - priceLimits: (0 << 128) | type(uint128).max + timeLimits: toPackedU128(block.timestamp + 100, block.timestamp + 1000), + amountInLimits: toPackedU128(0, type(uint128).max), + maxFees: toPackedU128(0, type(uint128).max), + priceLimits: toPackedU128(0, type(uint128).max) }); IntentData memory intentData = _getIntentData(0, type(uint128).max, condition); _setUpMainAddress(intentData, false); + bytes32[] memory _memLeaves = leaves; ActionData memory actionData = _getActionData( - intentData.tokenData, _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), false + intentData.tokenData, + _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), + false, + 0, + MerkleUtils.getProof(_memLeaves, 0), + condition[0] ); (address caller, bytes memory dkSignature, bytes memory gdSignature) = _getCallerAndSignatures(mode, intentData, actionData); vm.startPrank(caller); - vm.expectRevert(KSConditionalSwapHook.InvalidSwap.selector); + vm.expectRevert( + abi.encodeWithSelector( + KSConditionalSwapHook.InvalidTime.selector, + block.timestamp, + block.timestamp + 100, + block.timestamp + 1000 + ) + ); router.execute(intentData, dkSignature, guardian, gdSignature, actionData); } @@ -324,25 +402,38 @@ contract ConditionalSwapTest is BaseTest { condition[0] = KSConditionalSwapHook.SwapCondition({ swapLimit: 1, - timeLimits: ((vm.getBlockTimestamp() - 100) << 128) | (vm.getBlockTimestamp() + 100), - amountInLimits: (0 << 128) | type(uint128).max, - maxFees: (0 << 128) | type(uint128).max, - priceLimits: (uint256(type(uint128).max) << 128) | type(uint128).max + timeLimits: toPackedU128(block.timestamp - 100, block.timestamp + 100), + amountInLimits: toPackedU128(0, type(uint128).max), + maxFees: toPackedU128(0, type(uint128).max), + priceLimits: toPackedU128(type(uint128).max, type(uint128).max) }); IntentData memory intentData = _getIntentData(0, type(uint128).max, condition); _setUpMainAddress(intentData, false); + bytes32[] memory _memLeaves = leaves; ActionData memory actionData = _getActionData( - intentData.tokenData, _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), false + intentData.tokenData, + _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), + false, + 0, + MerkleUtils.getProof(_memLeaves, 0), + condition[0] ); (address caller, bytes memory dkSignature, bytes memory gdSignature) = _getCallerAndSignatures(mode, intentData, actionData); vm.startPrank(caller); - vm.expectRevert(KSConditionalSwapHook.InvalidSwap.selector); + vm.expectRevert( + abi.encodeWithSelector( + KSConditionalSwapHook.InvalidPrice.selector, + currentPrice, + type(uint128).max, + type(uint128).max + ) + ); router.execute(intentData, dkSignature, guardian, gdSignature, actionData); } @@ -355,15 +446,18 @@ contract ConditionalSwapTest is BaseTest { _setUpMainAddress(intentData, false); swapAmount = tmpSwapAmount; ActionData memory actionData; + bytes32[] memory _memLeaves = leaves; { TokenData memory tokenData; tokenData.erc20Data = new ERC20Data[](1); tokenData.erc20Data[0] = ERC20Data({token: tokenIn, amount: swapAmount, permitData: ''}); - actionData = _getActionData(tokenData, '', true); + actionData = _getActionData( + tokenData, '', true, 0, MerkleUtils.getProof(_memLeaves, 0), defaultCondition + ); } bytes32 hash = router.hashTypedIntentData(intentData); - assertEq(conditionalSwapHook.getSwapExecutionCount(hash, 0, 0), 0); + assertEq(conditionalSwapHook.getSwapExecutionCount(hash, 0), 0); { (address caller, bytes memory dkSignature, bytes memory gdSignature) = @@ -374,63 +468,75 @@ contract ConditionalSwapTest is BaseTest { actionData.nonce += 1; (caller, dkSignature, gdSignature) = _getCallerAndSignatures(mode, intentData, actionData); vm.startPrank(caller); - vm.expectRevert(KSConditionalSwapHook.InvalidSwap.selector); + vm.expectRevert(abi.encodeWithSelector(KSConditionalSwapHook.InvalidSwapLimit.selector, 2, 1)); router.execute(intentData, dkSignature, guardian, gdSignature, actionData); } { - assertEq(conditionalSwapHook.getSwapExecutionCount(hash, 0, 0), 1); + assertEq(conditionalSwapHook.getSwapExecutionCount(hash, 0), 1); } } - function testRevert_InvalidTokenIn(uint256 mode) public { + function testRevert_AmountInTooSmallOrTooLarge(uint256 mode, uint128 min, uint128 max) public { mode = bound(mode, 0, 2); + vm.assume(min < max && (min > swapAmount || max < swapAmount)); IntentData memory intentData = - _getIntentData(0, type(uint128).max, new KSConditionalSwapHook.SwapCondition[](0)); - _setUpMainAddress(intentData, false); - intentData.tokenData.erc20Data[0].token = makeAddr('dummy'); + _getIntentData(min, max, new KSConditionalSwapHook.SwapCondition[](0)); _setUpMainAddress(intentData, false); + bytes32[] memory _memLeaves = leaves; ActionData memory actionData = _getActionData( - intentData.tokenData, _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), false + intentData.tokenData, + _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), + false, + 0, + MerkleUtils.getProof(_memLeaves, 0), + defaultCondition ); - actionData.erc20Ids[0] = 0; - (address caller, bytes memory dkSignature, bytes memory gdSignature) = _getCallerAndSignatures(mode, intentData, actionData); vm.startPrank(caller); vm.expectRevert( - abi.encodeWithSelector( - KSConditionalSwapHook.InvalidTokenIn.selector, makeAddr('dummy'), tokenIn - ) + abi.encodeWithSelector(KSConditionalSwapHook.InvalidAmountIn.selector, swapAmount, min, max) ); router.execute(intentData, dkSignature, guardian, gdSignature, actionData); } - function testRevert_AmountInTooSmallOrTooLarge(uint256 mode, uint128 min, uint128 max) public { + function testRevert_ExceedFeeLimit(uint256 mode) public { + feeBefore = 1000; + feeAfter = 1000; + mode = bound(mode, 0, 2); - vm.assume(min < max && (min > swapAmount || max < swapAmount)); IntentData memory intentData = - _getIntentData(min, max, new KSConditionalSwapHook.SwapCondition[](0)); + _getIntentData(0, type(uint128).max, new KSConditionalSwapHook.SwapCondition[](0)); _setUpMainAddress(intentData, false); + uint256 beforeSwapFee = (swapAmount * feeBefore) / 1_000_000; + + bytes32[] memory _memLeaves = leaves; ActionData memory actionData = _getActionData( - intentData.tokenData, _adjustRecipient(feeAfter == 0 ? swapdata2 : swapdata), false + intentData.tokenData, + abi.encode(tokenIn, tokenOut, swapAmount - beforeSwapFee, 1000, address(router), mainAddress), + true, + 0, + MerkleUtils.getProof(_memLeaves, 0), + defaultCondition ); (address caller, bytes memory dkSignature, bytes memory gdSignature) = _getCallerAndSignatures(mode, intentData, actionData); vm.startPrank(caller); - vm.expectRevert(KSConditionalSwapHook.InvalidSwap.selector); + vm.expectRevert( + abi.encodeWithSelector( + KSConditionalSwapHook.InvalidFees.selector, feeBefore, feeAfter, maxSrcFee, maxDstFee + ) + ); router.execute(intentData, dkSignature, guardian, gdSignature, actionData); } - function testRevert_ExceedFeeLimit(uint256 mode) public { - feeBefore = 1000; - feeAfter = 1000; - + function testRevert_InvalidProof(uint256 mode) public { mode = bound(mode, 0, 2); IntentData memory intentData = _getIntentData(0, type(uint128).max, new KSConditionalSwapHook.SwapCondition[](0)); @@ -438,42 +544,135 @@ contract ConditionalSwapTest is BaseTest { uint256 beforeSwapFee = (swapAmount * feeBefore) / 1_000_000; + bytes32[] memory _memLeaves = leaves; ActionData memory actionData = _getActionData( intentData.tokenData, abi.encode(tokenIn, tokenOut, swapAmount - beforeSwapFee, 1000, address(router), mainAddress), - true + true, + 1, // wrong leaf index + MerkleUtils.getProof(_memLeaves, 0), + defaultCondition ); (address caller, bytes memory dkSignature, bytes memory gdSignature) = _getCallerAndSignatures(mode, intentData, actionData); vm.startPrank(caller); - vm.expectRevert(KSConditionalSwapHook.InvalidSwap.selector); + vm.expectRevert(KSConditionalSwapHook.InvalidProof.selector); router.execute(intentData, dkSignature, guardian, gdSignature, actionData); } - function _getActionData(TokenData memory tokenData, bytes memory actionCalldata, bool swapViaMock) - internal - view - returns (ActionData memory actionData) - { - FeeInfo memory feeInfo; - feeInfo.protocolRecipient = protocolRecipient; - feeInfo.partnerFeeConfigs = new FeeConfig[][](1); - feeInfo.partnerFeeConfigs[0] = _buildPartnersConfigs( - PartnersFeeConfigBuildParams({ - feeModes: [false].toMemoryArray(), - partnerFees: [uint24(1e6)].toMemoryArray(), - partnerRecipients: [partnerRecipient].toMemoryArray() - }) + function testRevert_MaxLeafIndex(uint256 mode) public { + mode = bound(mode, 0, 2); + + uint256 overflowLeafIndex = uint256(type(uint8).max) + 1; // 256 + + KSConditionalSwapHook.SwapCondition memory condition = KSConditionalSwapHook.SwapCondition({ + swapLimit: 1, + timeLimits: toPackedU128(block.timestamp, block.timestamp + 1 days), + amountInLimits: toPackedU128(0, type(uint128).max), + maxFees: toPackedU128(maxSrcFee, maxDstFee), + priceLimits: toPackedU128(0, type(uint128).max) + }); + + KSConditionalSwapHook.SwapCondition[] memory conditions = + new KSConditionalSwapHook.SwapCondition[](1); + conditions[0] = condition; + + _setUpLeaves( + [overflowLeafIndex].toMemoryArray(), + conditions, + [tokenIn].toMemoryArray(), + [tokenOut].toMemoryArray() + ); + + KSConditionalSwapHook.SwapHookData memory hookData; + hookData.root = root; + hookData.recipient = mainAddress; + + IntentCoreData memory coreData = IntentCoreData({ + mainAddress: mainAddress, + signatureVerifier: address(0), + delegatedKey: delegatedPublicKey, + actionContracts: [address(mockActionContract), swapRouter].toMemoryArray(), + actionSelectors: [MockActionContract.swap.selector, IKSSwapRouterV2.swap.selector] + .toMemoryArray(), + hook: address(conditionalSwapHook), + hookIntentData: abi.encode(hookData) + }); + + TokenData memory tokenData; + tokenData.erc20Data = new ERC20Data[](1); + tokenData.erc20Data[0] = ERC20Data({token: tokenIn, amount: swapAmount, permitData: ''}); + + IntentData memory intentData = + IntentData({coreData: coreData, tokenData: tokenData, extraData: ''}); + + _setUpMainAddress(intentData, false); + + bytes32[] memory _memLeaves = leaves; + ActionData memory actionData = _getActionData( + intentData.tokenData, + '', + true, + overflowLeafIndex, + MerkleUtils.getProof(_memLeaves, 0), + condition ); + (address caller, bytes memory dkSignature, bytes memory gdSignature) = + _getCallerAndSignatures(mode, intentData, actionData); + + vm.startPrank(caller); + vm.expectRevert(KSConditionalSwapHook.MaxLeafIndex.selector); + router.execute(intentData, dkSignature, guardian, gdSignature, actionData); + } + + function _setUpLeaves( + uint256[] memory leafIndexes, + KSConditionalSwapHook.SwapCondition[] memory conditions, + address[] memory _tokenIn, + address[] memory _tokenOut + ) internal returns (bytes32[] memory _leaves, bytes32 _root) { + leaves = new bytes32[](leafIndexes.length); + for (uint256 i = 0; i < leafIndexes.length; i++) { + leaves[i] = keccak256(abi.encode(leafIndexes[i], _tokenIn[i], _tokenOut[i], conditions[i])); + } + + root = leaves.getRoot(); + + return (leaves, root); + } + + function _getActionData( + TokenData memory tokenData, + bytes memory actionCalldata, + bool swapViaMock, + uint256 leafIndex, + bytes32[] memory proof, + KSConditionalSwapHook.SwapCondition memory condition + ) internal view returns (ActionData memory actionData) { + uint256 approvalFlags = (1 << (tokenData.erc20Data.length + tokenData.erc721Data.length)) - 1; + + FeeInfo memory _feeInfo; + { + _feeInfo.protocolRecipient = protocolRecipient; + _feeInfo.partnerFeeConfigs = new FeeConfig[][](1); + _feeInfo.partnerFeeConfigs[0] = _buildPartnersConfigs( + PartnersFeeConfigBuildParams({ + feeModes: [false].toMemoryArray(), + partnerFees: [uint24(1e6)].toMemoryArray(), + partnerRecipients: [partnerRecipient].toMemoryArray() + }) + ); + } + actionData = ActionData({ erc20Ids: [uint256(0)].toMemoryArray(), erc20Amounts: [tokenData.erc20Data[0].amount].toMemoryArray(), erc721Ids: new uint256[](0), - feeInfo: feeInfo, - approvalFlags: (1 << (tokenData.erc20Data.length + tokenData.erc721Data.length)) - 1, + feeInfo: _feeInfo, + approvalFlags: approvalFlags, actionSelectorId: swapViaMock ? 0 : 1, actionCalldata: swapViaMock ? (actionCalldata.length == 0 @@ -487,7 +686,9 @@ contract ConditionalSwapTest is BaseTest { ) : actionCalldata) : actionCalldata, - hookActionData: abi.encode(0, (feeBefore << 128) | feeAfter), + hookActionData: abi.encode( + proof, leafIndex, tokenOut, toPackedU128(feeBefore, feeAfter), condition + ), extraData: '', deadline: vm.getBlockTimestamp() + 1 days, nonce: 0 @@ -498,39 +699,58 @@ contract ConditionalSwapTest is BaseTest { uint256 min, uint256 max, KSConditionalSwapHook.SwapCondition[] memory swapConditions - ) internal view returns (IntentData memory intentData) { + ) internal returns (IntentData memory intentData) { KSConditionalSwapHook.SwapHookData memory hookData; - hookData.srcTokens = [tokenIn].toMemoryArray(); - hookData.dstTokens = [tokenOut].toMemoryArray(); - hookData.recipient = mainAddress; - hookData.swapConditions = new KSConditionalSwapHook.SwapCondition[][](1); if (swapConditions.length > 0) { - hookData.swapConditions[0] = swapConditions; + uint256[] memory leafIndexes = new uint256[](swapConditions.length); + address[] memory _tokenIn = new address[](swapConditions.length); + address[] memory _tokenOut = new address[](swapConditions.length); + for (uint256 i = 0; i < swapConditions.length; i++) { + leafIndexes[i] = i; + _tokenIn[i] = tokenIn; + _tokenOut[i] = tokenOut; + } + _setUpLeaves(leafIndexes, swapConditions, _tokenIn, _tokenOut); } else { - hookData.swapConditions[0] = new KSConditionalSwapHook.SwapCondition[](1); - hookData.swapConditions[0][0] = KSConditionalSwapHook.SwapCondition({ + defaultCondition = KSConditionalSwapHook.SwapCondition({ swapLimit: 1, - timeLimits: (vm.getBlockTimestamp() << 128) | (vm.getBlockTimestamp() + 1 days), - amountInLimits: (min << 128) | max, - maxFees: (maxSrcFee << 128) | maxDstFee, - priceLimits: (0 << 128) | type(uint128).max + timeLimits: toPackedU128(block.timestamp, block.timestamp + 1 days), + amountInLimits: toPackedU128(min, max), + maxFees: toPackedU128(maxSrcFee, maxDstFee), + priceLimits: toPackedU128(0, type(uint128).max) }); + + swapConditions = new KSConditionalSwapHook.SwapCondition[](1); + swapConditions[0] = defaultCondition; + + _setUpLeaves( + [uint256(0)].toMemoryArray(), + swapConditions, + [tokenIn].toMemoryArray(), + [tokenOut].toMemoryArray() + ); } - intentData.coreData.mainAddress = mainAddress; - intentData.coreData.signatureVerifier = address(0); - intentData.coreData.delegatedKey = delegatedPublicKey; - intentData.coreData.actionContracts = - [address(mockActionContract), address(swapRouter)].toMemoryArray(); - intentData.coreData.actionSelectors = - [MockActionContract.swap.selector, IKSSwapRouterV2.swap.selector].toMemoryArray(); - intentData.coreData.hook = address(conditionalSwapHook); - intentData.coreData.hookIntentData = abi.encode(hookData); - - intentData.tokenData.erc20Data = new ERC20Data[](1); - intentData.tokenData.erc20Data[0] = - ERC20Data({token: tokenIn, amount: swapAmount, permitData: ''}); + hookData.root = root; + hookData.recipient = mainAddress; + + IntentCoreData memory coreData = IntentCoreData({ + mainAddress: mainAddress, + signatureVerifier: address(0), + delegatedKey: delegatedPublicKey, + actionContracts: [address(mockActionContract), swapRouter].toMemoryArray(), + actionSelectors: [MockActionContract.swap.selector, IKSSwapRouterV2.swap.selector] + .toMemoryArray(), + hook: address(conditionalSwapHook), + hookIntentData: abi.encode(hookData) + }); + + TokenData memory tokenData; + tokenData.erc20Data = new ERC20Data[](1); + tokenData.erc20Data[0] = ERC20Data({token: tokenIn, amount: swapAmount, permitData: ''}); + + intentData = IntentData({coreData: coreData, tokenData: tokenData, extraData: ''}); } function _setUpMainAddress(IntentData memory intentData, bool withSignedIntent) internal { @@ -542,36 +762,6 @@ contract ConditionalSwapTest is BaseTest { vm.stopPrank(); } - function _getActionData(TokenData memory tokenData, bytes memory actionCalldata) - internal - view - returns (ActionData memory actionData) - { - actionData.feeInfo.protocolRecipient = protocolRecipient; - actionData.feeInfo.partnerFeeConfigs = new FeeConfig[][](1); - actionData.feeInfo.partnerFeeConfigs[0] = _buildPartnersConfigs( - PartnersFeeConfigBuildParams({ - feeModes: [false].toMemoryArray(), - partnerFees: [uint24(1e6)].toMemoryArray(), - partnerRecipients: [partnerRecipient].toMemoryArray() - }) - ); - - actionData = ActionData({ - erc20Ids: [uint256(0)].toMemoryArray(), - erc20Amounts: [tokenData.erc20Data[0].amount].toMemoryArray(), - erc721Ids: new uint256[](0), - feeInfo: actionData.feeInfo, - approvalFlags: (1 << (tokenData.erc20Data.length + tokenData.erc721Data.length)) - 1, - actionSelectorId: 0, - actionCalldata: actionCalldata, - hookActionData: abi.encode(0), - extraData: '', - deadline: vm.getBlockTimestamp() + 1 days, - nonce: 0 - }); - } - function _adjustRecipient(bytes memory data) internal view returns (bytes memory) { IKSSwapRouterV2.SwapExecutionParams memory params = abi.decode(data, (IKSSwapRouterV2.SwapExecutionParams)); diff --git a/test/utils/MerkleUtils.sol b/test/utils/MerkleUtils.sol new file mode 100644 index 0000000..1e3d629 --- /dev/null +++ b/test/utils/MerkleUtils.sol @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +pragma solidity ^0.8.0; + +import {Hashes} from 'openzeppelin-contracts/contracts/utils/cryptography/Hashes.sol'; + +library MerkleUtils { + function getRoot(bytes32[] memory leaves) internal pure returns (bytes32) { + while (leaves.length > 1) { + leaves = combine(leaves); + } + return leaves[0]; + } + + function getProof(bytes32[] memory leaves, uint256 node) + internal + pure + returns (bytes32[] memory proof) + { + unchecked { + proof = new bytes32[](log2Up(leaves.length)); + for (uint256 i = 0; i < proof.length; i++) { + if (node & 1 == 1) { + proof[i] = leaves[node - 1]; + } else if (node + 1 < leaves.length) { + proof[i] = leaves[node + 1]; + } + node >>= 1; + leaves = combine(leaves); + } + } + } + + function combine(bytes32[] memory leaves) internal pure returns (bytes32[] memory combined) { + unchecked { + uint256 length = leaves.length; + if (length & 1 == 1) { + combined = new bytes32[](length / 2 + 1); + combined[length / 2] = Hashes.commutativeKeccak256(leaves[length - 1], 0); + } else { + combined = new bytes32[](length / 2); + } + for (uint256 node = 0; node + 1 < length; node += 2) { + combined[node / 2] = Hashes.commutativeKeccak256(leaves[node], leaves[node + 1]); + } + } + } + + /// @dev Returns the log2 of `x`. + /// Equivalent to computing the index of the most significant bit (MSB) of `x`. + /// Returns 0 if `x` is zero. + function log2(uint256 x) internal pure returns (uint256 r) { + /// @solidity memory-safe-assembly + assembly { + r := shl(7, lt(0xffffffffffffffffffffffffffffffff, x)) + r := or(r, shl(6, lt(0xffffffffffffffff, shr(r, x)))) + r := or(r, shl(5, lt(0xffffffff, shr(r, x)))) + r := or(r, shl(4, lt(0xffff, shr(r, x)))) + r := or(r, shl(3, lt(0xff, shr(r, x)))) + // forgefmt: disable-next-item + r := or(r, byte(and(0x1f, shr(shr(r, x), 0x8421084210842108cc6318c6db6d54be)), + 0x0706060506020504060203020504030106050205030304010505030400000000)) + } + } + + /// @dev Returns the log2 of `x`, rounded up. + /// Returns 0 if `x` is zero. + function log2Up(uint256 x) internal pure returns (uint256 r) { + r = log2(x); + /// @solidity memory-safe-assembly + assembly { + r := add(r, lt(shl(r, 1), x)) + } + } +}