diff --git a/jvector-native/src/main/c/jextract_vector_simd.sh b/jvector-native/src/main/c/jextract_vector_simd.sh index d44d375dd..45767e6e6 100755 --- a/jvector-native/src/main/c/jextract_vector_simd.sh +++ b/jvector-native/src/main/c/jextract_vector_simd.sh @@ -49,7 +49,7 @@ CURRENT_GCC_VERSION=$(gcc -dumpversion) # Check if the current GCC version is greater than or equal to the minimum required version if [ "$(printf '%s\n' "$MIN_GCC_VERSION" "$CURRENT_GCC_VERSION" | sort -V | head -n1)" = "$MIN_GCC_VERSION" ]; then rm -rf ../resources/libjvector.so - gcc -fPIC -O3 -march=icelake-server -c jvector_simd.c -o jvector_simd.o + gcc -fPIC -O3 -march=skylake-avx512 -c jvector_simd.c -o jvector_simd.o gcc -fPIC -O3 -march=x86-64 -c jvector_simd_check.c -o jvector_simd_check.o gcc -shared -o ../resources/libjvector.so jvector_simd_check.o jvector_simd.o @@ -77,4 +77,4 @@ jextract \ jvector_simd.h # Set critical linker option with heap-based segments for all generated methods -sed -i 's/DESC)/DESC, Linker.Option.critical(true))/g' ../java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java \ No newline at end of file +sed -i 's/DESC)/DESC, Linker.Option.critical(true))/g' ../java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index d9c909c0f..a46fab62e 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -19,23 +19,27 @@ #include #include "jvector_simd.h" -__m512i initialIndexRegister; -__m512i indexIncrement; -__m512i maskSeventhBit; -__m512i maskEighthBit; - -__attribute__((constructor)) -void initialize_constants() { - if (check_compatibility()) { - initialIndexRegister = _mm512_setr_epi32(-16, -15, -14, -13, -12, -11, -10, -9, - -8, -7, -6, -5, -4, -3, -2, -1); - indexIncrement = _mm512_set1_epi32(16); - maskSeventhBit = _mm512_set1_epi16(0x0040); - maskEighthBit = _mm512_set1_epi16(0x0080); - } + 
+JV_FINLINE float reduce_add_256_ps(__m256 v) { + __m128 lo = _mm256_castps256_ps128(v); + __m128 hi = _mm256_extractf128_ps(v, 1); + __m128 sum128 = _mm_add_ps(lo, hi); + __m128 shuf = _mm_movehdup_ps(sum128); + __m128 sums = _mm_add_ps(sum128, shuf); + shuf = _mm_movehl_ps(shuf, sums); + sums = _mm_add_ss(sums, shuf); + return _mm_cvtss_f32(sums); +} + +JV_FINLINE float reduce_add_128_ps(__m128 v) { + __m128 shuf = _mm_movehdup_ps(v); + __m128 sums = _mm_add_ps(v, shuf); + shuf = _mm_movehl_ps(shuf, sums); + sums = _mm_add_ss(sums, shuf); + return _mm_cvtss_f32(sums); } -float dot_product_f32_64(const float* a, int aoffset, const float* b, int boffset) { +JV_FINLINE float dot_product_f32_64(const float* a, int aoffset, const float* b, int boffset) { __m128 va = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(a + aoffset))); __m128 vb = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(b + boffset))); @@ -47,7 +51,7 @@ float dot_product_f32_64(const float* a, int aoffset, const float* b, int boffse return result[0] + result[1]; } -float dot_product_f32_128(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float dot_product_f32_128(const float* a, int aoffset, const float* b, int boffset, int length) { float dot = 0.0; int ao = aoffset; int bo = boffset; @@ -55,26 +59,17 @@ float dot_product_f32_128(const float* a, int aoffset, const float* b, int boffs int blim = boffset + length; int simd_length = length - (length % 4); - if (length >= 4) { - __m128 sum = _mm_setzero_ps(); - - for(; ao < aoffset + simd_length; ao += 4, bo += 4) { - // Load float32 - __m128 va = _mm_loadu_ps(a + ao); - __m128 vb = _mm_loadu_ps(b + bo); - - // Multiply and accumulate - sum = _mm_fmadd_ps(va, vb, sum); - } + __m128 sum = _mm_setzero_ps(); - // Horizontal sum of the vector to get dot product - __attribute__((aligned(16))) float result[4]; - _mm_store_ps(result, sum); + for(; ao < aoffset + simd_length; ao += 4, bo += 4) { + // Load float32 + __m128 va = 
_mm_loadu_ps(a + ao); + __m128 vb = _mm_loadu_ps(b + bo); - for(int i = 0; i < 4; ++i) { - dot += result[i]; - } + // Multiply and accumulate + sum = _mm_fmadd_ps(va, vb, sum); } + dot = reduce_add_128_ps(sum); for (; ao < alim && bo < blim; ao++, bo++) { dot += a[ao] * b[bo]; @@ -83,7 +78,7 @@ float dot_product_f32_128(const float* a, int aoffset, const float* b, int boffs return dot; } -float dot_product_f32_256(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float dot_product_f32_256(const float* a, int aoffset, const float* b, int boffset, int length) { float dot = 0.0; int ao = aoffset; int bo = boffset; @@ -91,27 +86,20 @@ float dot_product_f32_256(const float* a, int aoffset, const float* b, int boffs int blim = boffset + length; int simd_length = length - (length % 8); - if (length >= 8) { - __m256 sum = _mm256_setzero_ps(); - - for(; ao < aoffset + simd_length; ao += 8, bo += 8) { - // Load float32 - __m256 va = _mm256_loadu_ps(a + ao); - __m256 vb = _mm256_loadu_ps(b + bo); - - // Multiply and accumulate - sum = _mm256_fmadd_ps(va, vb, sum); - } + __m256 sum = _mm256_setzero_ps(); - // Horizontal sum of the vector to get dot product - __attribute__((aligned(32))) float result[8]; - _mm256_store_ps(result, sum); + for(; ao < aoffset + simd_length; ao += 8, bo += 8) { + // Load float32 + __m256 va = _mm256_loadu_ps(a + ao); + __m256 vb = _mm256_loadu_ps(b + bo); - for(int i = 0; i < 8; ++i) { - dot += result[i]; - } + // Multiply and accumulate + sum = _mm256_fmadd_ps(va, vb, sum); } + // Horizontal sum of the vector to get dot product + dot = reduce_add_256_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { dot += a[ao] * b[bo]; } @@ -119,7 +107,7 @@ float dot_product_f32_256(const float* a, int aoffset, const float* b, int boffs return dot; } -float dot_product_f32_512(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float dot_product_f32_512(const float* a, int aoffset, const 
float* b, int boffset, int length) { float dot = 0.0; int ao = aoffset; int bo = boffset; @@ -127,21 +115,19 @@ float dot_product_f32_512(const float* a, int aoffset, const float* b, int boffs int blim = boffset + length; int simd_length = length - (length % 16); - if (length >= 16) { - __m512 sum = _mm512_setzero_ps(); - for(; ao < aoffset + simd_length; ao += 16, bo += 16) { - // Load float32 - __m512 va = _mm512_loadu_ps(a + ao); - __m512 vb = _mm512_loadu_ps(b + bo); - - // Multiply and accumulate - sum = _mm512_fmadd_ps(va, vb, sum); - } + __m512 sum = _mm512_setzero_ps(); + for(; ao < aoffset + simd_length; ao += 16, bo += 16) { + // Load float32 + __m512 va = _mm512_loadu_ps(a + ao); + __m512 vb = _mm512_loadu_ps(b + bo); - // Horizontal sum of the vector to get dot product - dot = _mm512_reduce_add_ps(sum); + // Multiply and accumulate + sum = _mm512_fmadd_ps(va, vb, sum); } + // Horizontal sum of the vector to get dot product + dot = _mm512_reduce_add_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { dot += a[ao] * b[bo]; } @@ -149,18 +135,18 @@ float dot_product_f32_512(const float* a, int aoffset, const float* b, int boffs return dot; } -float dot_product_f32(int preferred_size, const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float dot_product_f32(const float* a, int aoffset, const float* b, int boffset, int length) { if (length == 2) return dot_product_f32_64(a, aoffset, b, boffset); if (length <= 7) return dot_product_f32_128(a, aoffset, b, boffset, length); - return (preferred_size == 512 && length >= 16) + return (length >= 16) ? 
dot_product_f32_512(a, aoffset, b, boffset, length) : dot_product_f32_256(a, aoffset, b, boffset, length); } -float euclidean_f32_64(const float* a, int aoffset, const float* b, int boffset) { +JV_FINLINE float euclidean_f32_64(const float* a, int aoffset, const float* b, int boffset) { __m128 va = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(a + aoffset))); __m128 vb = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(b + boffset))); __m128 r = _mm_sub_ps(va, vb); @@ -172,7 +158,7 @@ float euclidean_f32_64(const float* a, int aoffset, const float* b, int boffset) return result[0] + result[1]; } -float euclidean_f32_128(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float euclidean_f32_128(const float* a, int aoffset, const float* b, int boffset, int length) { float squareDistance = 0.0; int ao = aoffset; int bo = boffset; @@ -180,27 +166,20 @@ float euclidean_f32_128(const float* a, int aoffset, const float* b, int boffset int blim = boffset + length; int simd_length = length - (length % 4); - if (length >= 4) { - __m128 sum = _mm_setzero_ps(); - - for(; ao < aoffset + simd_length; ao += 4, bo += 4) { - // Load float32 - __m128 va = _mm_loadu_ps(a + ao); - __m128 vb = _mm_loadu_ps(b + bo); - __m128 diff = _mm_sub_ps(va, vb); - // Multiply and accumulate - sum = _mm_fmadd_ps(diff, diff, sum); - } + __m128 sum = _mm_setzero_ps(); - // Horizontal sum of the vector to get dot product - __attribute__((aligned(16))) float result[4]; - _mm_store_ps(result, sum); - - for(int i = 0; i < 4; ++i) { - squareDistance += result[i]; - } + for(; ao < aoffset + simd_length; ao += 4, bo += 4) { + // Load float32 + __m128 va = _mm_loadu_ps(a + ao); + __m128 vb = _mm_loadu_ps(b + bo); + __m128 diff = _mm_sub_ps(va, vb); + // Multiply and accumulate + sum = _mm_fmadd_ps(diff, diff, sum); } + // Horizontal sum of the vector to get dot product + squareDistance = reduce_add_128_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { float diff = a[ao] - 
b[bo]; squareDistance += diff * diff; @@ -209,7 +188,7 @@ float euclidean_f32_128(const float* a, int aoffset, const float* b, int boffset return squareDistance; } -float euclidean_f32_256(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float euclidean_f32_256(const float* a, int aoffset, const float* b, int boffset, int length) { float squareDistance = 0.0; int ao = aoffset; int bo = boffset; @@ -217,27 +196,20 @@ float euclidean_f32_256(const float* a, int aoffset, const float* b, int boffset int blim = boffset + length; int simd_length = length - (length % 8); - if (length >= 8) { - __m256 sum = _mm256_setzero_ps(); - - for(; ao < aoffset + simd_length; ao += 8, bo += 8) { - // Load float32 - __m256 va = _mm256_loadu_ps(a + ao); - __m256 vb = _mm256_loadu_ps(b + bo); - __m256 diff = _mm256_sub_ps(va, vb); + __m256 sum = _mm256_setzero_ps(); - // Multiply and accumulate - sum = _mm256_fmadd_ps(diff, diff, sum); - } + for(; ao < aoffset + simd_length; ao += 8, bo += 8) { + // Load float32 + __m256 va = _mm256_loadu_ps(a + ao); + __m256 vb = _mm256_loadu_ps(b + bo); + __m256 diff = _mm256_sub_ps(va, vb); - __attribute__((aligned(32))) float result[8]; - _mm256_store_ps(result, sum); - - for(int i = 0; i < 8; ++i) { - squareDistance += result[i]; - } + // Multiply and accumulate + sum = _mm256_fmadd_ps(diff, diff, sum); } + squareDistance = reduce_add_256_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { float diff = a[ao] - b[bo]; squareDistance += diff * diff; @@ -246,7 +218,7 @@ float euclidean_f32_256(const float* a, int aoffset, const float* b, int boffset return squareDistance; } -float euclidean_f32_512(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float euclidean_f32_512(const float* a, int aoffset, const float* b, int boffset, int length) { float squareDistance = 0.0; int ao = aoffset; int bo = boffset; @@ -254,22 +226,20 @@ float euclidean_f32_512(const float* a, int aoffset, 
const float* b, int boffset int blim = boffset + length; int simd_length = length - (length % 16); - if (length >= 16) { - __m512 sum = _mm512_setzero_ps(); - for(; ao < aoffset + simd_length; ao += 16, bo += 16) { - // Load float32 - __m512 va = _mm512_loadu_ps(a + ao); - __m512 vb = _mm512_loadu_ps(b + bo); - __m512 diff = _mm512_sub_ps(va, vb); - - // Multiply and accumulate - sum = _mm512_fmadd_ps(diff, diff, sum); - } - - // Horizontal sum of the vector to get dot product - squareDistance = _mm512_reduce_add_ps(sum); + __m512 sum = _mm512_setzero_ps(); + for(; ao < aoffset + simd_length; ao += 16, bo += 16) { + // Load float32 + __m512 va = _mm512_loadu_ps(a + ao); + __m512 vb = _mm512_loadu_ps(b + bo); + __m512 diff = _mm512_sub_ps(va, vb); + + // Multiply and accumulate + sum = _mm512_fmadd_ps(diff, diff, sum); } + // Horizontal sum of the vector to get dot product + squareDistance = _mm512_reduce_add_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { float diff = a[ao] - b[bo]; squareDistance += diff * diff; @@ -278,103 +248,229 @@ float euclidean_f32_512(const float* a, int aoffset, const float* b, int boffset return squareDistance; } -float euclidean_f32(int preferred_size, const float* a, int aoffset, const float* b, int boffset, int length) { +JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int boffset, int length) { if (length == 2) return euclidean_f32_64(a, aoffset, b, boffset); if (length <= 7) return euclidean_f32_128(a, aoffset, b, boffset, length); - return (preferred_size == 512 && length >= 16) + return (length >= 16) ? 
euclidean_f32_512(a, aoffset, b, boffset, length) : euclidean_f32_256(a, aoffset, b, boffset, length); } -float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { - __m512 sum = _mm512_setzero_ps(); - int i = 0; - int limit = baseOffsetsLength - (baseOffsetsLength % 16); - __m512i indexRegister = initialIndexRegister; - __m512i dataBaseVec = _mm512_set1_epi32(dataBase); - baseOffsets = baseOffsets + baseOffsetsOffset; - - for (; i < limit; i += 16) { - __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); - __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); - // we have base offsets int, which we need to scale to index into data. - // first, we want to initialize a vector with the lane number added as an index - indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); - // then we want to multiply by dataBase - __m512i scale = _mm512_mullo_epi32(indexRegister, dataBaseVec); - // then we want to add the base offsets - __m512i convOffsets = _mm512_add_epi32(scale, baseOffsetsInt); - - __m512 partials = _mm512_i32gather_ps(convOffsets, data, 4); - sum = _mm512_add_ps(sum, partials); +JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { + int codebookBase = codebookIndex * clusterCount; + float tempdat[16]; + if (size == 2) { + int i = 0; + // use a zmm register to calculate 8 partial sums in parallel: + __m128 q_lo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(query + queryOffset))); + __m512 qq = _mm512_broadcast_f32x2(q_lo); // broadcast 2 query floats to all 8 x 64-bit positions + for (; i + 8 <= clusterCount; i += 8) { + // load eight consecutive centroids (16 floats) from the codebook into zmm + __m512 c = _mm512_loadu_ps(codebook + i * size); + __m512 prod = _mm512_mul_ps(c, qq); + // horizontal reduce: sum the 
two products within each 64-bit centroid slot + // shuffle swaps pairs within each 128-bit lane: [a,b,c,d] -> [b,a,d,c] + __m512 temp = _mm512_shuffle_ps(prod, prod, _MM_SHUFFLE(2, 3, 0, 1)); + __m512 sum = _mm512_add_ps(prod, temp); + // results sit at even positions (0,2,4,6,8,10,12,14) + // regular store and load seem to be better than vcompress or vpermutex2var for extracting the results + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[2]; + partialSums[codebookBase + i + 2] = tempdat[4]; + partialSums[codebookBase + i + 3] = tempdat[6]; + partialSums[codebookBase + i + 4] = tempdat[8]; + partialSums[codebookBase + i + 5] = tempdat[10]; + partialSums[codebookBase + i + 6] = tempdat[12]; + partialSums[codebookBase + i + 7] = tempdat[14]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + } } - - float res = _mm512_reduce_add_ps(sum); - for (; i < baseOffsetsLength; i++) { - res += data[dataBase * i + baseOffsets[i]]; + else if (size == 4) { + int i = 0; + // use a zmm register to calculate 4 partial sums in parallel: + __m128 q = _mm_loadu_ps(query + queryOffset); + __m512 qq = _mm512_broadcast_f32x4(q); // broadcast 128-bit query to all 4 lanes + for (; i + 4 <= clusterCount; i += 4) { + // load four consecutive centroids from the codebook into zmm + __m512 c = _mm512_loadu_ps(codebook + i * size); + __m512 sum = _mm512_fmadd_ps(c, qq, _mm512_setzero_ps()); + // horizontal reduce: within each 128-bit lane independently + // Step 1: swap neighboring elements within 128-bit lanes + __m512 temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); + sum = _mm512_add_ps(sum, temp); + // Step 2: swap 32-bit pairs within 128-bit lanes + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm512_add_ps(sum, temp); + // extract results from position 0 of each 128-bit lane + 
_mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[4]; + partialSums[codebookBase + i + 2] = tempdat[8]; + partialSums[codebookBase + i + 3] = tempdat[12]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + } } - - return res; -} - -float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { - __m512 sum = _mm512_setzero_ps(); - __m512 vaMagnitude = _mm512_setzero_ps(); - int i = 0; - int limit = baseOffsetsLength - (baseOffsetsLength % 16); - __m512i indexRegister = initialIndexRegister; - __m512i scale = _mm512_set1_epi32(clusterCount); - baseOffsets = baseOffsets + baseOffsetsOffset; - - - for (; i < limit; i += 16) { - // Load and convert baseOffsets to integers - __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); - __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); - - indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); - // Scale the baseOffsets by the cluster count - __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); - - // Calculate the final convOffsets by adding the scaled indexes and the base offsets - __m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt); - - // Gather and sum values for partial sums and a magnitude - __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); - sum = _mm512_add_ps(sum, partialSumVals); - - __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4); - vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals); + else if (size == 8) { + int i = 0; + // use a zmm register to calculate 2 partial sums in parallel: + __m256 q = _mm256_loadu_ps(query + queryOffset); + __m512 qq = _mm512_broadcast_f32x8(q); // 8 cycles, 
but have to do it just once outside the loop + for (; i + 2 <= clusterCount; i += 2) { + // load two consecutive centroids from the codebook into zmm + __m512 c1 = _mm512_loadu_ps(codebook + i * size); + __m512 sum = _mm512_fmadd_ps(c1, qq, _mm512_setzero_ps()); + // horizontal reduce: per 256 bit lanes + // Step 1: swap neighbouring 128 bits and add to sum across lanes + __m512 temp = _mm512_shuffle_f32x4(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); // swap 128-bit lanes + sum = _mm512_add_ps(sum, temp); + // Step 2: Shuffle and add to sum within lanes + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm512_add_ps(sum, temp); + // step 3: shuffle neighboring lanes: + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); + sum = _mm512_add_ps(sum, temp); + // extract results: maybe there is a better way? + // Store is cheap and loading them should happen from the store buffers, so this may be faster than shuffling and extracting: + // Although it's tempting, avoid using vcompress (a high latency instruction) + //_mm512_mask_compressstoreu_ps(ans, 0x8080, sum); + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[8]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 16) { + int i = 0; + __m512 qq = _mm512_loadu_ps(query + queryOffset); + for (; i < clusterCount; i += 1) { + __m512 c1 = _mm512_loadu_ps(codebook + i * size); + __m512 sum = _mm512_fmadd_ps(qq, c1, _mm512_setzero_ps()); + partialSums[codebookBase + i] = 
_mm512_reduce_add_ps(sum); + } } - - return sumResult / sqrtf(aMagnitudeResult * bMagnitude); -} - -void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { - int codebookBase = codebookIndex * clusterCount; - for (int i = 0; i < clusterCount; i++) { - partialSums[codebookBase + i] = dot_product_f32(512, codebook, i * size, query, queryOffset, size); + else { + for (int i = 0; i < clusterCount; i++) { + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + } } } -void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { +JV_INLINE void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; - for (int i = 0; i < clusterCount; i++) { - partialSums[codebookBase + i] = euclidean_f32(512, codebook, i * size, query, queryOffset, size); + float tempdat[16]; + if (size == 2) { + int i = 0; + // use a zmm register to calculate 8 partial sums in parallel: + __m128 q_lo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(query + queryOffset))); + __m512 qq = _mm512_broadcast_f32x2(q_lo); // broadcast 2 query floats to all 8 x 64-bit positions + for (; i + 8 <= clusterCount; i += 8) { + // load eight consecutive centroids (16 floats) from the codebook into zmm + __m512 c = _mm512_loadu_ps(codebook + i * size); + __m512 diff = _mm512_sub_ps(c, qq); + __m512 sq = _mm512_mul_ps(diff, diff); + // horizontal reduce: sum the two squared diffs within each 64-bit centroid slot + // shuffle swaps pairs within each 128-bit lane: [a,b,c,d] -> [b,a,d,c] + __m512 temp = _mm512_shuffle_ps(sq, sq, _MM_SHUFFLE(2, 3, 0, 1)); + __m512 sum = _mm512_add_ps(sq, temp); + // 
results sit at even positions (0,2,4,6,8,10,12,14) + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[2]; + partialSums[codebookBase + i + 2] = tempdat[4]; + partialSums[codebookBase + i + 3] = tempdat[6]; + partialSums[codebookBase + i + 4] = tempdat[8]; + partialSums[codebookBase + i + 5] = tempdat[10]; + partialSums[codebookBase + i + 6] = tempdat[12]; + partialSums[codebookBase + i + 7] = tempdat[14]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 4) { + int i = 0; + // use a zmm register to calculate 4 partial sums in parallel: + __m128 q = _mm_loadu_ps(query + queryOffset); + __m512 qq = _mm512_broadcast_f32x4(q); // broadcast 128-bit query to all 4 lanes + for (; i + 4 <= clusterCount; i += 4) { + // load four consecutive centroids from the codebook into zmm + __m512 c = _mm512_loadu_ps(codebook + i * size); + __m512 diff = _mm512_sub_ps(c, qq); + __m512 sum = _mm512_fmadd_ps(diff, diff, _mm512_setzero_ps()); + // horizontal reduce: within each 128-bit lane independently + // Step 1: swap neighboring elements within 128-bit lanes + __m512 temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); + sum = _mm512_add_ps(sum, temp); + // Step 2: swap 32-bit pairs within 128-bit lanes + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm512_add_ps(sum, temp); + // extract results from position 0 of each 128-bit lane + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[4]; + partialSums[codebookBase + i + 2] = tempdat[8]; + partialSums[codebookBase + i + 3] = tempdat[12]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 8) { + int i = 0; + // use a zmm register to calculate 2 
partial sums in parallel: + __m256 q = _mm256_loadu_ps(query + queryOffset); + __m512 qq = _mm512_broadcast_f32x8(q); // 8 cycles, but have to do it just once outside the loop + for (; i + 2 <= clusterCount; i += 2) { + // load two consecutive centroids from the codebook into zmm + __m512 c1 = _mm512_loadu_ps(codebook + i * size); + __m512 diff = _mm512_sub_ps(c1, qq); + __m512 sum = _mm512_fmadd_ps(diff, diff, _mm512_setzero_ps()); + // horizontal reduce: per 256 bit lanes + // Step 1: swap neighbouring 128 bits and add to sum across lanes + __m512 temp = _mm512_shuffle_f32x4(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); // swap 128-bit lanes + sum = _mm512_add_ps(sum, temp); + // Step 2: Shuffle and add to sum within lanes + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm512_add_ps(sum, temp); + // step 3: shuffle neighboring lanes: + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); + sum = _mm512_add_ps(sum, temp); + // extract results: may be there is a better way? 
+ // Store is cheap and loading them should happen from the store buffers, so this may be faster than shuffling and extracting: + // Although its tempting, avoid using vcompress (a high latency instruction) + //_mm512_mask_compressstoreu_ps(ans, 0x8080, sum); + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[8]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 16) { + int i = 0; + __m512 qq = _mm512_loadu_ps(query + queryOffset); + for (; i < clusterCount; i += 1) { + __m512 c1 = _mm512_loadu_ps(codebook + i * size); + __m512 diff = _mm512_sub_ps(c1, qq); + __m512 sum = _mm512_fmadd_ps(diff, diff, _mm512_setzero_ps()); + partialSums[codebookBase + i] = _mm512_reduce_add_ps(sum); + } + } + else { + for (int i = 0; i < clusterCount; i++) { + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + } } } @@ -395,7 +491,7 @@ void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codeboo */ -__attribute__((always_inline)) inline __m512i lookup_partial_sums(__m512i shuffle, const char* quantizedPartials, int i) { +JV_FINLINE __m512i lookup_partial_sums(__m512i shuffle, const char* quantizedPartials, int i) { __m512i partialsVecA = _mm512_loadu_epi16(quantizedPartials + i * 512); __m512i partialsVecB = _mm512_loadu_epi16(quantizedPartials + i * 512 + 64); __m512i partialsVecC = _mm512_loadu_epi16(quantizedPartials + i * 512 + 128); @@ -410,6 +506,8 @@ __attribute__((always_inline)) inline __m512i lookup_partial_sums(__m512i shuffl __m512i partialsVecEF = _mm512_permutex2var_epi16(partialsVecE, shuffle, partialsVecF); __m512i partialsVecGH = _mm512_permutex2var_epi16(partialsVecG, shuffle, partialsVecH); + const __m512i maskSeventhBit = _mm512_set1_epi16(0x0040); + const __m512i maskEighthBit = _mm512_set1_epi16(0x0080); __mmask32 
maskSeven = _mm512_test_epi16_mask(shuffle, maskSeventhBit); __mmask32 maskEight = _mm512_test_epi16_mask(shuffle, maskEighthBit); __m512i partialsVecABCD = _mm512_mask_blend_epi16(maskSeven, partialsVecAB, partialsVecCD); @@ -420,7 +518,7 @@ __attribute__((always_inline)) inline __m512i lookup_partial_sums(__m512i shuffl } // dequantize a 256-bit vector containing 16 unsigned 16-bit integers into a 512-bit vector containing 16 32-bit floats -__attribute__((always_inline)) inline __m512 dequantize(__m256i quantizedVec, float delta, float base) { +JV_FINLINE __m512 dequantize(__m256i quantizedVec, float delta, float base) { __m512i quantizedVecWidened = _mm512_cvtepu16_epi32(quantizedVec); __m512 floatVec = _mm512_cvtepi32_ps(quantizedVecWidened); __m512 deltaVec = _mm512_set1_ps(delta); @@ -429,7 +527,7 @@ __attribute__((always_inline)) inline __m512 dequantize(__m256i quantizedVec, fl return dequantizedVec; } -void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results) { +JV_INLINE void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results) { __m512i sum = _mm512_setzero_epi32(); for (int i = 0; i < codebookCount; i++) { @@ -454,7 +552,7 @@ void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int _mm512_storeu_ps(results + 16, resultsRight); } -void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float best, float* results) { +JV_INLINE void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float best, float* results) { __m512i sum = _mm512_setzero_epi32(); for (int i = 0; i < codebookCount; i++) { @@ -478,7 +576,7 @@ void 
bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb _mm512_storeu_ps(results + 16, resultsRight); } -void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results) { +JV_INLINE void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results) { __m512i sum = _mm512_setzero_epi32(); __m512i magnitude = _mm512_setzero_epi32(); @@ -520,11 +618,11 @@ void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int co } // Partial sum calculations that also record best distances, as this is necessary for Fused ADC quantization -void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances) { +JV_INLINE void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances) { float best = -INFINITY; int codebookBase = codebookIndex * clusterCount; for (int i = 0; i < clusterCount; i++) { - float val = dot_product_f32(512, codebook, i * size, query, queryOffset, size); + float val = dot_product_f32(codebook, i * size, query, queryOffset, size); partialSums[codebookBase + i] = val; if (val > best) { best = val; @@ -533,15 +631,112 @@ void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebook partialBestDistances[codebookIndex] = best; } -void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, 
int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances) { +JV_INLINE void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances) { float best = INFINITY; int codebookBase = codebookIndex * clusterCount; for (int i = 0; i < clusterCount; i++) { - float val = euclidean_f32(512, codebook, i * size, query, queryOffset, size); + float val = euclidean_f32(codebook, i * size, query, queryOffset, size); partialSums[codebookBase + i] = val; if (val < best) { best = val; } } partialBestDistances[codebookIndex] = best; -} \ No newline at end of file +} + +/* List APIs exposed to Java via FFI here: Do not mark them static or inline, + * as they need to be visible to the dynamic linker and we may want to + * benchmark them individually in C. + */ + +float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { + __m512 sum = _mm512_setzero_ps(); + int i = 0; + int limit = baseOffsetsLength - (baseOffsetsLength % 16); + const __m512i initialIndexRegister = _mm512_setr_epi32(-16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1); + const __m512i indexIncrement = _mm512_set1_epi32(16); + __m512i indexRegister = initialIndexRegister; + __m512i dataBaseVec = _mm512_set1_epi32(dataBase); + baseOffsets = baseOffsets + baseOffsetsOffset; + + for (; i < limit; i += 16) { + __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); + __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); + // we have base offsets int, which we need to scale to index into data.
+ // first, we want to initialize a vector with the lane number added as an index + indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); + // then we want to multiply by dataBase + __m512i scale = _mm512_mullo_epi32(indexRegister, dataBaseVec); + // then we want to add the base offsets + __m512i convOffsets = _mm512_add_epi32(scale, baseOffsetsInt); + + __m512 partials = _mm512_i32gather_ps(convOffsets, data, 4); + sum = _mm512_add_ps(sum, partials); + } + + float res = _mm512_reduce_add_ps(sum); + for (; i < baseOffsetsLength; i++) { + res += data[dataBase * i + baseOffsets[i]]; + } + + return res; +} + +float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { + __m512 sum = _mm512_setzero_ps(); + __m512 vaMagnitude = _mm512_setzero_ps(); + int i = 0; + int limit = baseOffsetsLength - (baseOffsetsLength % 16); + const __m512i initialIndexRegister = _mm512_setr_epi32(-16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1); + const __m512i indexIncrement = _mm512_set1_epi32(16); + __m512i indexRegister = initialIndexRegister; + __m512i scale = _mm512_set1_epi32(clusterCount); + baseOffsets = baseOffsets + baseOffsetsOffset; + + + for (; i < limit; i += 16) { + // Load and convert baseOffsets to integers + __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); + __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); + + indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); + // Scale the baseOffsets by the cluster count + __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); + + // Calculate the final convOffsets by adding the scaled indexes and the base offsets + __m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt); + + // Gather and sum values for partial sums and a magnitude + __m512 partialSumVals = 
_mm512_i32gather_ps(convOffsets, partialSums, 4); + sum = _mm512_add_ps(sum, partialSumVals); + + __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4); + vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals); + } + + // Reduce sums + float sumResult = _mm512_reduce_add_ps(sum); + float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude); + + // Handle the remaining elements + for (; i < baseOffsetsLength; i++) { + int offset = clusterCount * i + baseOffsets[i]; + sumResult += partialSums[offset]; + aMagnitudeResult += aMagnitude[offset]; + } + + return sumResult / sqrtf(aMagnitudeResult * bMagnitude); +} + +void calculate_partial_sums_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, int similarityFunction, float* partialSums) { + switch (similarityFunction) { + case 0: + calculate_partial_sums_euclidean_f32_512(codebook, codebookIndex, size, clusterCount, query, queryOffset, partialSums); + break; + case 1: + calculate_partial_sums_dot_f32_512(codebook, codebookIndex, size, clusterCount, query, queryOffset, partialSums); + break; + default: + break; + } +} diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index 55f1a46c1..f39e26f88 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -19,19 +19,13 @@ #ifndef VECTOR_SIMD_DOT_H #define VECTOR_SIMD_DOT_H +#define JV_INLINE static inline +#define JV_FINLINE static inline __attribute__((always_inline)) // check CPU support bool check_compatibility(void); -//F32 -float dot_product_f32(int preferred_size, const float* a, int aoffset, const float* b, int boffset, int length); -float euclidean_f32(int preferred_size, const float* a, int aoffset, const float* b, int boffset, int length); -void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, 
float* results); -void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results); -void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results); +// APIs exposed to Java via FFI float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength); float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude); -void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); -void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); -void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); -void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); -#endif \ No newline at end of file +void calculate_partial_sums_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, int similarityFunction, float* partialSums); +#endif diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java 
b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 48cd7d66e..bf65b181a 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -104,4 +104,17 @@ public float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffse // encoded is a pointer into a PQ chunk - we need to index into it by encodedOffset and provide encodedLength to the native code return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encodedOffset, encodedLength, clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); } + + @Override + public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat partialSums) { + var nativeCodebook = ((MemorySegmentVectorFloat) codebook).get(); + var nativeQuery = ((MemorySegmentVectorFloat) query).get(); + var nativePartialSums = ((MemorySegmentVectorFloat) partialSums).get(); + int similarityFunction = switch (vsf) { + case EUCLIDEAN -> 0; + case DOT_PRODUCT -> 1; + default -> throw new UnsupportedOperationException("Unsupported similarity function " + vsf); + }; + NativeSimdOps.calculate_partial_sums_f32_512(nativeCodebook, codebookIndex, size, clusterCount, nativeQuery, queryOffset, similarityFunction, nativePartialSums); + } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index 014bdf4b0..5b31c7e02 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java 
@@ -847,4 +847,36 @@ public static void calculate_partial_sums_best_euclidean_f32_512(MemorySegment c throw new AssertionError("should not reach here", ex$); } } + + private static class calculate_partial_sums_f32_512 { + public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER + ); + public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("calculate_partial_sums_f32_512"); + public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC, Linker.Option.critical(true)); + } + + /** + * {@snippet lang=c : + * void calculate_partial_sums_f32_512(const float *codebook, int codebookIndex, int size, int clusterCount, const float *query, int queryOffset, int similarityFunction, float *partialSums) + * } + */ + public static void calculate_partial_sums_f32_512(MemorySegment codebook, int codebookIndex, int size, int clusterCount, MemorySegment query, int queryOffset, int similarityFunction, MemorySegment partialSums) { + var mh$ = calculate_partial_sums_f32_512.HANDLE; + try { + if (TRACE_DOWNCALLS) { + traceDowncall("calculate_partial_sums_f32_512", codebook, codebookIndex, size, clusterCount, query, queryOffset, similarityFunction, partialSums); + } + mh$.invokeExact(codebook, codebookIndex, size, clusterCount, query, queryOffset, similarityFunction, partialSums); + } catch (Throwable ex$) { + throw new AssertionError("should not reach here", ex$); + } + } } \ No newline at end of file