Skip to content

Commit 70cd2fb

Browse files
author
Raghuveer Devulapalli
committed
perf: Optimize PQ distance lookup table calculation for subvector sizes 4, 8 & 16 on AVX-512
Add SIMD fast paths in calculate_partial_sums_dot_f32_512 and calculate_partial_sums_euclidean_f32_512 for the most common PQ subvector sizes: - size == 4: broadcast a 128-bit query fragment across all four 128-bit lanes of a ZMM register, load four consecutive centroids at once, and reduce each lane independently using two shuffle+add pairs. Produces 4 partial sums per loop iteration instead of 1. - size == 8: broadcast a 256-bit query fragment across both 256-bit halves of a ZMM register, load two consecutive centroids at once, and reduce across 128-bit lanes followed by within-lane shuffles. Produces 2 partial sums per loop iteration instead of 1. - size == 16: query and the centroid fit into a ZMM register, load the query into zmm once and then loop over the centroids. Produces one partial sum per loop iteration, but prevents having to load the query multiple times. All paths fall back to the default way of computing dot_product_f32 / euclidean_f32 in a loop for any tail elements or unsupported sizes.
1 parent 8453d57 commit 70cd2fb

1 file changed

Lines changed: 147 additions & 4 deletions

File tree

jvector-native/src/main/c/jvector_simd.c

Lines changed: 147 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,15 +261,158 @@ JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int b
261261

262262
JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
    /*
     * Fill partialSums[codebookBase .. codebookBase + clusterCount - 1] with the
     * dot product between each centroid of this codebook and the query fragment
     * starting at query[queryOffset]. Each centroid is `size` consecutive floats
     * at codebook[i * size].
     *
     * SIMD fast paths exist for the common PQ subvector sizes:
     *   size == 4:  4 centroids per ZMM load -> 4 partial sums per iteration
     *   size == 8:  2 centroids per ZMM load -> 2 partial sums per iteration
     *   size == 16: 1 centroid per ZMM load; the query is loaded only once
     * Any other size (and loop tails) falls back to dot_product_f32.
     */
    int codebookBase = codebookIndex * clusterCount;
    float tempdat[16];
    if (size == 4) {
        int i = 0;
        /* Broadcast the 4-float query fragment to all four 128-bit lanes so
         * each lane processes one centroid independently. */
        __m128 q = _mm_loadu_ps(query + queryOffset);
        __m512 qq = _mm512_broadcast_f32x4(q);
        for (; i + 4 <= clusterCount; i += 4) {
            /* Load four consecutive centroids (16 floats) in one go. */
            __m512 c = _mm512_loadu_ps(codebook + i * size);
            __m512 sum = _mm512_mul_ps(c, qq);
            /* Horizontal reduce within each 128-bit lane independently.
             * Step 1: swap neighboring elements within each lane and add. */
            __m512 temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1));
            sum = _mm512_add_ps(sum, temp);
            /* Step 2: swap 64-bit pairs within each lane and add; every element
             * of a lane now holds that lane's total. */
            temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2));
            sum = _mm512_add_ps(sum, temp);
            /* Extract element 0 of each 128-bit lane. A plain store plus scalar
             * reloads is cheap thanks to store-to-load forwarding. */
            _mm512_storeu_ps(tempdat, sum);
            partialSums[codebookBase + i] = tempdat[0];
            partialSums[codebookBase + i + 1] = tempdat[4];
            partialSums[codebookBase + i + 2] = tempdat[8];
            partialSums[codebookBase + i + 3] = tempdat[12];
        }
        /* Tail: fewer than 4 centroids remain. */
        for (; i < clusterCount; i++) {
            partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size);
        }
    }
    else if (size == 8) {
        int i = 0;
        /* Broadcast the 8-float query fragment to both 256-bit halves; done
         * once outside the loop. */
        __m256 q = _mm256_loadu_ps(query + queryOffset);
        __m512 qq = _mm512_broadcast_f32x8(q);
        for (; i + 2 <= clusterCount; i += 2) {
            /* Load two consecutive centroids (16 floats) in one go. */
            __m512 c1 = _mm512_loadu_ps(codebook + i * size);
            __m512 sum = _mm512_mul_ps(c1, qq);
            /* Horizontal reduce within each 256-bit half.
             * Step 1: swap the two 128-bit lanes of each half and add. */
            __m512 temp = _mm512_shuffle_f32x4(sum, sum, _MM_SHUFFLE(2, 3, 0, 1));
            sum = _mm512_add_ps(sum, temp);
            /* Steps 2+3: finish the reduction within each 128-bit lane. */
            temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2));
            sum = _mm512_add_ps(sum, temp);
            temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1));
            sum = _mm512_add_ps(sum, temp);
            /* Extract element 0 of each 256-bit half. Store + reload should be
             * served from the store buffer; deliberately avoiding vcompress,
             * a high-latency instruction:
             *   _mm512_mask_compressstoreu_ps(ans, 0x8080, sum); */
            _mm512_storeu_ps(tempdat, sum);
            partialSums[codebookBase + i] = tempdat[0];
            partialSums[codebookBase + i + 1] = tempdat[8];
        }
        /* Tail: at most one centroid remains. */
        for (; i < clusterCount; i++) {
            partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size);
        }
    }
    else if (size == 16) {
        /* One centroid exactly fills a ZMM register; load the query once. */
        __m512 q = _mm512_loadu_ps(query + queryOffset);
        for (int i = 0; i < clusterCount; i++) {
            __m512 c1 = _mm512_loadu_ps(codebook + i * size);
            /* BUG FIX: previous code computed c1 * c1 (the centroid times
             * itself) instead of centroid times query. */
            __m512 sum = _mm512_mul_ps(c1, q);
            partialSums[codebookBase + i] = _mm512_reduce_add_ps(sum);
        }
    }
    else {
        /* Generic path for any other subvector size. */
        for (int i = 0; i < clusterCount; i++) {
            partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size);
        }
    }
}
268338

269339
JV_INLINE void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
    /*
     * Fill partialSums[codebookBase .. codebookBase + clusterCount - 1] with the
     * squared Euclidean distance between each centroid of this codebook and the
     * query fragment starting at query[queryOffset]. Each centroid is `size`
     * consecutive floats at codebook[i * size].
     *
     * SIMD fast paths exist for the common PQ subvector sizes:
     *   size == 4:  4 centroids per ZMM load -> 4 partial sums per iteration
     *   size == 8:  2 centroids per ZMM load -> 2 partial sums per iteration
     *   size == 16: 1 centroid per ZMM load; the query is loaded only once
     * Any other size (and loop tails) falls back to euclidean_f32.
     */
    int codebookBase = codebookIndex * clusterCount;
    float tempdat[16];
    if (size == 4) {
        int i = 0;
        /* Broadcast the 4-float query fragment to all four 128-bit lanes so
         * each lane processes one centroid independently. */
        __m128 q = _mm_loadu_ps(query + queryOffset);
        __m512 qq = _mm512_broadcast_f32x4(q);
        for (; i + 4 <= clusterCount; i += 4) {
            /* Load four consecutive centroids (16 floats) in one go. */
            __m512 c = _mm512_loadu_ps(codebook + i * size);
            __m512 diff = _mm512_sub_ps(c, qq);
            __m512 sum = _mm512_mul_ps(diff, diff);
            /* Horizontal reduce within each 128-bit lane independently.
             * Step 1: swap neighboring elements within each lane and add. */
            __m512 temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1));
            sum = _mm512_add_ps(sum, temp);
            /* Step 2: swap 64-bit pairs within each lane and add; every element
             * of a lane now holds that lane's total. */
            temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2));
            sum = _mm512_add_ps(sum, temp);
            /* Extract element 0 of each 128-bit lane. A plain store plus scalar
             * reloads is cheap thanks to store-to-load forwarding. */
            _mm512_storeu_ps(tempdat, sum);
            partialSums[codebookBase + i] = tempdat[0];
            partialSums[codebookBase + i + 1] = tempdat[4];
            partialSums[codebookBase + i + 2] = tempdat[8];
            partialSums[codebookBase + i + 3] = tempdat[12];
        }
        /* Tail: fewer than 4 centroids remain. */
        for (; i < clusterCount; i++) {
            partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size);
        }
    }
    else if (size == 8) {
        int i = 0;
        /* Broadcast the 8-float query fragment to both 256-bit halves; done
         * once outside the loop. */
        __m256 q = _mm256_loadu_ps(query + queryOffset);
        __m512 qq = _mm512_broadcast_f32x8(q);
        for (; i + 2 <= clusterCount; i += 2) {
            /* Load two consecutive centroids (16 floats) in one go. */
            __m512 c1 = _mm512_loadu_ps(codebook + i * size);
            __m512 diff = _mm512_sub_ps(c1, qq);
            __m512 sum = _mm512_mul_ps(diff, diff);
            /* Horizontal reduce within each 256-bit half.
             * Step 1: swap the two 128-bit lanes of each half and add. */
            __m512 temp = _mm512_shuffle_f32x4(sum, sum, _MM_SHUFFLE(2, 3, 0, 1));
            sum = _mm512_add_ps(sum, temp);
            /* Steps 2+3: finish the reduction within each 128-bit lane. */
            temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2));
            sum = _mm512_add_ps(sum, temp);
            temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1));
            sum = _mm512_add_ps(sum, temp);
            /* Extract element 0 of each 256-bit half. Store + reload should be
             * served from the store buffer; deliberately avoiding vcompress,
             * a high-latency instruction:
             *   _mm512_mask_compressstoreu_ps(ans, 0x8080, sum); */
            _mm512_storeu_ps(tempdat, sum);
            partialSums[codebookBase + i] = tempdat[0];
            partialSums[codebookBase + i + 1] = tempdat[8];
        }
        /* Tail: at most one centroid remains. */
        for (; i < clusterCount; i++) {
            partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size);
        }
    }
    else if (size == 16) {
        /* One centroid exactly fills a ZMM register; load the query once. */
        __m512 q = _mm512_loadu_ps(query + queryOffset);
        for (int i = 0; i < clusterCount; i++) {
            __m512 c1 = _mm512_loadu_ps(codebook + i * size);
            /* BUG FIX: previous code referenced `qq`, which does not exist in
             * this branch (compile error); the query vector here is `q`. */
            __m512 diff = _mm512_sub_ps(c1, q);
            __m512 sum = _mm512_mul_ps(diff, diff);
            partialSums[codebookBase + i] = _mm512_reduce_add_ps(sum);
        }
    }
    else {
        /* Generic path for any other subvector size. */
        for (int i = 0; i < clusterCount; i++) {
            partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size);
        }
    }
}
275418

0 commit comments

Comments
 (0)