Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/constants/BitMask.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

uint256 constant MASK_1_BIT = 0x1;
uint256 constant MASK_8_BITS = 0xff;
uint256 constant MASK_24_BITS = 0xffffff;
uint256 constant MASK_127_BITS = 0x7fffffffffffffffffffffffffffffff;
uint256 constant MASK_128_BITS = 0xffffffffffffffffffffffffffffffff;
uint160 constant MASK_160_BITS = 0x00ffffffffffffffffffffffffffffffffffffffff;

uint256 constant MASK_BYTES_4 = 0xffffffff00000000000000000000000000000000000000000000000000000000;
186 changes: 99 additions & 87 deletions src/hooks/swap/KSConditionalSwapHook.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,28 @@
pragma solidity ^0.8.0;

import {IKSSmartIntentHook} from '../../interfaces/hooks/IKSSmartIntentHook.sol';
import {PackedU128} from '../../types/PackedU128.sol';
import {BaseStatefulHook} from '../base/BaseStatefulHook.sol';
import {CalldataDecoder} from 'ks-common-sc/src/libraries/calldata/CalldataDecoder.sol';
import {TokenHelper} from 'ks-common-sc/src/libraries/token/TokenHelper.sol';
import {MerkleProof} from 'openzeppelin-contracts/contracts/utils/cryptography/MerkleProof.sol';

import {ActionData} from '../../types/ActionData.sol';
import {IntentData} from '../../types/IntentData.sol';

contract KSConditionalSwapHook is BaseStatefulHook {
using TokenHelper for address;
using CalldataDecoder for bytes;

error InvalidTokenIn(address tokenIn, address actualTokenIn);
error AmountInMismatch(uint256 amountIn, uint256 actualAmountIn);
error InvalidSwap();
error InvalidProof();
error InvalidTime(uint256 timestamp, uint256 min, uint256 maxT);
error InvalidAmountIn(uint256 amountIn, uint256 minAmountIn, uint256 maxAmountIn);
error InvalidFees(
uint256 srcFeePercent, uint256 dstFeePercent, uint256 maxSrcFee, uint256 maxDstFee
);
error InvalidPrice(uint256 price, uint256 minPrice, uint256 maxPrice);
error InvalidSwapLimit(uint256 swapCount, uint256 limit);

uint256 public constant DENOMINATOR = 1e18;
uint256 public constant PRECISION = 1_000_000;
Expand All @@ -26,9 +36,7 @@ contract KSConditionalSwapHook is BaseStatefulHook {
* @param recipient The recipient of the destination token
*/
struct SwapHookData {
SwapCondition[][] swapConditions;
address[] srcTokens;
address[] dstTokens;
bytes32 root;
address recipient;
}

Expand All @@ -42,16 +50,16 @@ contract KSConditionalSwapHook is BaseStatefulHook {
*/
struct SwapCondition {
uint8 swapLimit;
uint256 timeLimits;
uint256 amountInLimits;
uint256 maxFees;
uint256 priceLimits;
PackedU128 timeLimits;
PackedU128 amountInLimits;
PackedU128 maxFees;
PackedU128 priceLimits;
}

struct SwapValidationData {
SwapCondition[] swapConditions;
SwapCondition swapCondition;
bytes32 intentHash;
uint256 intentIndex;
uint256 leafIndex;
address tokenIn;
address tokenOut;
uint256 amountIn;
Expand All @@ -64,14 +72,12 @@ contract KSConditionalSwapHook is BaseStatefulHook {

/**
* @notice Tracks swap execution counts for each condition to enforce swap limits
* @dev Maps intentHash -> intentIndex -> packedIndexes -> packedCounts
* @dev Maps intentHash -> packedIndexes -> packedCounts
* Each uint256 stores up to 32 uint8 swap counts (8 bits each), indexed by swapIndexes / 32
* Individual counts are extracted using bit shifts based on swapIndexes % 32
*/
mapping(
bytes32 intentHash
=> mapping(uint256 intentIndex => mapping(uint256 swapIndexes => uint256 swapCount))
) public swapRecord;
mapping(bytes32 intentHash => mapping(uint256 swapIndexes => uint256 swapCount)) public
swapRecord;

constructor(address[] memory initialRouters) BaseStatefulHook(initialRouters) {}

Expand All @@ -81,7 +87,6 @@ contract KSConditionalSwapHook is BaseStatefulHook {
_;
}

/// @inheritdoc IKSSmartIntentHook
function beforeExecution(
bytes32 intentHash,
IntentData calldata intentData,
Expand All @@ -94,25 +99,28 @@ contract KSConditionalSwapHook is BaseStatefulHook {
returns (uint256[] memory fees, bytes memory beforeExecutionData)
{
SwapHookData calldata swapHookData = _decodeHookData(intentData.coreData.hookIntentData);
(uint256 index, uint256 intentSrcFee, uint256 intentDstFee) =
_decodeAndValidateHookActionData(actionData.hookActionData, swapHookData);
(
bytes32[] calldata proof,
SwapCondition calldata condition,
uint256 leafIndex,
address tokenOut,
uint256 intentSrcFee,
uint256 intentDstFee
) = _decodeHookActionData(actionData.hookActionData);

address tokenIn = intentData.tokenData.erc20Data[actionData.erc20Ids[0]].token;
address tokenOut = swapHookData.dstTokens[index];
uint256 amountIn = actionData.erc20Amounts[0];

require(
tokenIn == swapHookData.srcTokens[index],
InvalidTokenIn(tokenIn, swapHookData.srcTokens[index])
);
bytes32 leaf = keccak256(abi.encode(leafIndex, tokenIn, tokenOut, condition));
require(MerkleProof.verifyCalldata(proof, swapHookData.root, leaf), InvalidProof());

uint256 amountIn = actionData.erc20Amounts[0];
fees = new uint256[](1);
fees[0] = (amountIn * intentSrcFee) / PRECISION;
beforeExecutionData = abi.encode(
SwapValidationData({
swapConditions: swapHookData.swapConditions[index],
swapCondition: condition,
intentHash: intentHash,
intentIndex: index,
leafIndex: leafIndex,
tokenIn: tokenIn,
tokenOut: tokenOut,
amountIn: amountIn,
Expand All @@ -129,7 +137,6 @@ contract KSConditionalSwapHook is BaseStatefulHook {
return (fees, beforeExecutionData);
}

/// @inheritdoc IKSSmartIntentHook
function afterExecution(
bytes32,
IntentData calldata intentData,
Expand Down Expand Up @@ -162,8 +169,9 @@ contract KSConditionalSwapHook is BaseStatefulHook {
uint256 price = (amountOut * DENOMINATOR) / amountIn;
Comment thread
minhtr09 marked this conversation as resolved.
Outdated

_validateSwapCondition(
validationData.swapConditions,
swapRecord[validationData.intentHash][validationData.intentIndex],
validationData.swapCondition,
validationData.leafIndex,
swapRecord[validationData.intentHash],
price,
amountIn,
validationData.srcFeePercent,
Expand Down Expand Up @@ -191,60 +199,60 @@ contract KSConditionalSwapHook is BaseStatefulHook {
/**
* @notice Gets the number of times a specific swap condition has been executed
* @param intentHash The hash of the intent
* @param intentIndex The index of the specific intent
* @param conditionIndex The index of the swap condition to check
* @return The number of times this condition has been executed
*/
function getSwapExecutionCount(bytes32 intentHash, uint256 intentIndex, uint256 conditionIndex)
function getSwapExecutionCount(bytes32 intentHash, uint256 conditionIndex)
public
view
returns (uint256)
{
uint256 packedValue = swapRecord[intentHash][intentIndex][conditionIndex / 32];
uint256 packedValue = swapRecord[intentHash][conditionIndex / 32];
uint256 bytePosition = conditionIndex % 32;

return uint8(packedValue >> (bytePosition * 8));
}

function _validateSwapCondition(
SwapCondition[] calldata swapCondition,
SwapCondition calldata condition,
uint256 index,
mapping(uint256 swapIndexes => uint256 swapCounts) storage record,
uint256 price,
uint256 amountIn,
uint256 srcFeePercent,
uint256 dstFeePercent
) internal {
for (uint256 i; i < swapCondition.length; ++i) {
SwapCondition calldata condition = swapCondition[i];

if (
block.timestamp < condition.timeLimits >> 128
|| block.timestamp > uint128(condition.timeLimits)
) {
continue;
}

if (
amountIn < condition.amountInLimits >> 128 || amountIn > uint128(condition.amountInLimits)
) {
continue;
}

if (srcFeePercent > condition.maxFees >> 128 || dstFeePercent > uint128(condition.maxFees)) {
continue;
}

if (price < condition.priceLimits >> 128 || price > uint128(condition.priceLimits)) {
continue;
}

if (!_increaseByOne(record, uint8(i), condition.swapLimit)) {
continue;
}
return;
if (
block.timestamp < condition.timeLimits.value0()
|| block.timestamp > condition.timeLimits.value1()
) {
revert InvalidTime(
block.timestamp, condition.timeLimits.value0(), condition.timeLimits.value1()
);
}

if (
amountIn < condition.amountInLimits.value0() || amountIn > condition.amountInLimits.value1()
) {
revert InvalidAmountIn(
amountIn, condition.amountInLimits.value0(), condition.amountInLimits.value1()
);
}

if (srcFeePercent > condition.maxFees.value0() || dstFeePercent > condition.maxFees.value1()) {
revert InvalidFees(
srcFeePercent, dstFeePercent, condition.maxFees.value0(), condition.maxFees.value1()
);
}

if (price < condition.priceLimits.value0() || price > condition.priceLimits.value1()) {
revert InvalidPrice(price, condition.priceLimits.value0(), condition.priceLimits.value1());
}

revert InvalidSwap();
(bool success, uint8 swapCount) = _increaseByOne(record, uint8(index), condition.swapLimit);
Comment thread
minhtr09 marked this conversation as resolved.
Outdated
if (!success) {
revert InvalidSwapLimit(swapCount, condition.swapLimit);
}
}

/**
Expand All @@ -259,21 +267,21 @@ contract KSConditionalSwapHook is BaseStatefulHook {
mapping(uint256 packedIndexes => uint256 packedValues) storage record,
uint8 index,
uint8 limit
) internal returns (bool) {
) internal returns (bool, uint8) {
uint256 packedValue = record[index / 32];
uint256 bytePosition = index % 32;

uint8 swapCount = uint8(packedValue >> (bytePosition * 8)) + 1;

if (swapCount > limit) {
return false;
return (false, swapCount);
}

packedValue += 1 << (bytePosition * 8);

record[index / 32] = packedValue;

return true;
return (true, swapCount);
Comment thread
minhtr09 marked this conversation as resolved.
Outdated
}

function _getRecipientBalance(address tokenOut, address recipient, uint256 feePercent)
Expand All @@ -287,42 +295,46 @@ contract KSConditionalSwapHook is BaseStatefulHook {
return tokenOut.balanceOf(recipient);
}

// @dev: equivalent to abi.decode(data, (SwapCondition))
function _decodeSwapCondition(bytes calldata data)
internal
pure
returns (SwapCondition calldata swapCondition)
{
assembly ('memory-safe') {
swapCondition := data.offset
}
}

// @dev: equivalent to abi.decode(data, (SwapHookData))
function _decodeHookData(bytes calldata data)
internal
pure
returns (SwapHookData calldata hookData)
{
assembly ('memory-safe') {
hookData := add(data.offset, calldataload(data.offset))
hookData := data.offset
}
}

// @dev: equivalent to abi.decode(data, (uint256, uint256, uint256, uint256))
function _decodeAndValidateHookActionData(bytes calldata data, SwapHookData calldata swapHookData)
function _decodeHookActionData(bytes calldata data)
internal
view
returns (uint256 index, uint256 intentSrcFee, uint256 intentDstFee)
pure
returns (
bytes32[] calldata proof,
SwapCondition calldata condition,
uint256 leafIndex,
address tokenOut,
uint256 intentSrcFee,
uint256 intentDstFee
)
{
uint256 packedFees;
PackedU128 packedFees;
assembly ('memory-safe') {
leafIndex := calldataload(add(data.offset, 0x20))
tokenOut := calldataload(add(data.offset, 0x40))
packedFees := calldataload(add(data.offset, 0x60))
condition := add(data.offset, 0x80)
}

(uint256 length, uint256 offset) = data.decodeLengthOffset(0);
assembly ('memory-safe') {
index := calldataload(data.offset)
packedFees := calldataload(add(data.offset, 0x20))
proof.length := length
proof.offset := offset
}

intentSrcFee = packedFees >> 128;
intentDstFee = uint128(packedFees);
intentSrcFee = packedFees.value0();
intentDstFee = packedFees.value1();
}

// @dev: equivalent to abi.decode(data, (SwapValidationData))
Expand All @@ -332,7 +344,7 @@ contract KSConditionalSwapHook is BaseStatefulHook {
returns (SwapValidationData calldata validationData)
{
assembly ('memory-safe') {
validationData := add(data.offset, calldataload(data.offset))
validationData := data.offset
}
}
}
Loading
Loading