@@ -261,15 +261,158 @@ JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int b
261261
262262JV_INLINE void calculate_partial_sums_dot_f32_512 (const float * codebook , int codebookIndex , int size , int clusterCount , const float * query , int queryOffset , float * partialSums ) {
263263 int codebookBase = codebookIndex * clusterCount ;
264- for (int i = 0 ; i < clusterCount ; i ++ ) {
265- partialSums [codebookBase + i ] = dot_product_f32 (codebook , i * size , query , queryOffset , size );
264+ float tempdat [16 ];
265+ if (size == 4 ) {
266+ int i = 0 ;
267+ // use a zmm register to calculate 4 partial sums in parallel:
268+ __m128 q = _mm_loadu_ps (query + queryOffset );
269+ __m512 qq = _mm512_broadcast_f32x4 (q ); // broadcast 128-bit query to all 4 lanes
270+ for (; i + 4 <= clusterCount ; i += 4 ) {
271+ // load four consecutive centroids from the codebook into zmm
272+ __m512 c = _mm512_loadu_ps (codebook + i * size );
273+ __m512 sum = _mm512_fmadd_ps (c , qq , _mm512_setzero_ps ());
274+ // horizontal reduce: within each 128-bit lane independently
275+ // Step 1: swap neighboring elements within 128-bit lanes
276+ __m512 temp = _mm512_shuffle_ps (sum , sum , _MM_SHUFFLE (2 , 3 , 0 , 1 ));
277+ sum = _mm512_add_ps (sum , temp );
278+ // Step 2: swap 32-bit pairs within 128-bit lanes
279+ temp = _mm512_shuffle_ps (sum , sum , _MM_SHUFFLE (1 , 0 , 3 , 2 ));
280+ sum = _mm512_add_ps (sum , temp );
281+ // extract results from position 0 of each 128-bit lane
282+ _mm512_storeu_ps (tempdat , sum );
283+ partialSums [codebookBase + i ] = tempdat [0 ];
284+ partialSums [codebookBase + i + 1 ] = tempdat [4 ];
285+ partialSums [codebookBase + i + 2 ] = tempdat [8 ];
286+ partialSums [codebookBase + i + 3 ] = tempdat [12 ];
287+ }
288+ for (; i < clusterCount ; i ++ ) {
289+ partialSums [codebookBase + i ] = dot_product_f32 (codebook , i * size , query , queryOffset , size );
290+ }
291+ }
292+ else if (size == 8 ) {
293+ int i = 0 ;
294+ // use a zmm register to calculate 2 partial sums in parallel:
295+ __m256 q = _mm256_loadu_ps (query + queryOffset );
296+ __m512 qq = _mm512_broadcast_f32x8 (q ); // 8 cycles, but have to do it just once outside the loop
297+ for (; i + 2 <= clusterCount ; i += 2 ) {
298+ // load two consecutive centroids from the codebook into zmm
299+ __m512 c1 = _mm512_loadu_ps (codebook + i * size );
300+ __m512 sum = _mm512_fmadd_ps (c1 , qq , _mm512_setzero_ps ());
301+ // horizontal reduce: per 256 bit lanes
302+ // Step 1: swap neighbouring 128 bits and add to sum across lanes
303+ __m512 temp = _mm512_shuffle_f32x4 (sum , sum , _MM_SHUFFLE (2 , 3 , 0 , 1 )); // swap 128-bit lanes
304+ sum = _mm512_add_ps (sum , temp );
305+ // Step 2: Shuffle and add to sum within lanes
306+ temp = _mm512_shuffle_ps (sum , sum , _MM_SHUFFLE (1 , 0 , 3 , 2 ));
307+ sum = _mm512_add_ps (sum , temp );
308+ // step 3: shuffle neighboring lanes:
309+ temp = _mm512_shuffle_ps (sum , sum , _MM_SHUFFLE (2 , 3 , 0 , 1 ));
310+ sum = _mm512_add_ps (sum , temp );
311+ // extract results: may be there is a better way?
312+ // Store is cheap and loading them should happen from the store buffers, so this may be faster than shuffling and extracting:
313+ // Although its tempting, avoid using vcompress (a high latency instruction)
314+ //_mm512_mask_compressstoreu_ps(ans, 0x8080, sum);
315+ _mm512_storeu_ps (tempdat , sum );
316+ partialSums [codebookBase + i ] = tempdat [0 ];
317+ partialSums [codebookBase + i + 1 ] = tempdat [8 ];
318+ }
319+ for (; i < clusterCount ; i ++ ) {
320+ partialSums [codebookBase + i ] = dot_product_f32 (codebook , i * size , query , queryOffset , size );
321+ }
322+ }
323+ else if (size == 16 ) {
324+ int i = 0 ;
325+ __m512 q = _mm512_loadu_ps (query + queryOffset );
326+ for (; i < clusterCount ; i += 1 ) {
327+ __m512 c1 = _mm512_loadu_ps (codebook + i * size );
328+ __m512 sum = _mm512_fmadd_ps (c1 , c1 , _mm512_setzero_ps ());
329+ partialSums [codebookBase + i ] = _mm512_reduce_add_ps (sum );
330+ }
331+ }
332+ else {
333+ for (int i = 0 ; i < clusterCount ; i ++ ) {
334+ partialSums [codebookBase + i ] = dot_product_f32 (codebook , i * size , query , queryOffset , size );
335+ }
266336 }
267337}
268338
269339JV_INLINE void calculate_partial_sums_euclidean_f32_512 (const float * codebook , int codebookIndex , int size , int clusterCount , const float * query , int queryOffset , float * partialSums ) {
270340 int codebookBase = codebookIndex * clusterCount ;
271- for (int i = 0 ; i < clusterCount ; i ++ ) {
272- partialSums [codebookBase + i ] = euclidean_f32 (codebook , i * size , query , queryOffset , size );
341+ float tempdat [16 ];
342+ if (size == 4 ) {
343+ int i = 0 ;
344+ // use a zmm register to calculate 4 partial sums in parallel:
345+ __m128 q = _mm_loadu_ps (query + queryOffset );
346+ __m512 qq = _mm512_broadcast_f32x4 (q ); // broadcast 128-bit query to all 4 lanes
347+ for (; i + 4 <= clusterCount ; i += 4 ) {
348+ // load four consecutive centroids from the codebook into zmm
349+ __m512 c = _mm512_loadu_ps (codebook + i * size );
350+ __m512 diff = _mm512_sub_ps (c , qq );
351+ __m512 sum = _mm512_fmadd_ps (diff , diff , _mm512_setzero_ps ());
352+ // horizontal reduce: within each 128-bit lane independently
353+ // Step 1: swap neighboring elements within 128-bit lanes
354+ __m512 temp = _mm512_shuffle_ps (sum , sum , _MM_SHUFFLE (2 , 3 , 0 , 1 ));
355+ sum = _mm512_add_ps (sum , temp );
356+ // Step 2: swap 32-bit pairs within 128-bit lanes
357+ temp = _mm512_shuffle_ps (sum , sum , _MM_SHUFFLE (1 , 0 , 3 , 2 ));
358+ sum = _mm512_add_ps (sum , temp );
359+ // extract results from position 0 of each 128-bit lane
360+ _mm512_storeu_ps (tempdat , sum );
361+ partialSums [codebookBase + i ] = tempdat [0 ];
362+ partialSums [codebookBase + i + 1 ] = tempdat [4 ];
363+ partialSums [codebookBase + i + 2 ] = tempdat [8 ];
364+ partialSums [codebookBase + i + 3 ] = tempdat [12 ];
365+ }
366+ for (; i < clusterCount ; i ++ ) {
367+ partialSums [codebookBase + i ] = euclidean_f32 (codebook , i * size , query , queryOffset , size );
368+ }
369+ }
370+ else if (size == 8 ) {
371+ int i = 0 ;
372+ // use a zmm register to calculate 2 partial sums in parallel:
373+ __m256 q = _mm256_loadu_ps (query + queryOffset );
374+ __m512 qq = _mm512_broadcast_f32x8 (q ); // 8 cycles, but have to do it just once outside the loop
375+ for (; i + 2 <= clusterCount ; i += 2 ) {
376+ // load two consecutive centroids from the codebook into zmm
377+ __m512 c1 = _mm512_loadu_ps (codebook + i * size );
378+ __m512 diff = _mm512_sub_ps (c1 , qq );
379+ __m512 sum = _mm512_fmadd_ps (diff , diff , _mm512_setzero_ps ());
380+ // horizontal reduce: per 256 bit lanes
381+ // Step 1: swap neighbouring 128 bits and add to sum across lanes
382+ __m512 temp = _mm512_shuffle_f32x4 (sum , sum , _MM_SHUFFLE (2 , 3 , 0 , 1 )); // swap 128-bit lanes
383+ sum = _mm512_add_ps (sum , temp );
384+ // Step 2: Shuffle and add to sum within lanes
385+ temp = _mm512_shuffle_ps (sum , sum , _MM_SHUFFLE (1 , 0 , 3 , 2 ));
386+ sum = _mm512_add_ps (sum , temp );
387+ // step 3: shuffle neighboring lanes:
388+ temp = _mm512_shuffle_ps (sum , sum , _MM_SHUFFLE (2 , 3 , 0 , 1 ));
389+ sum = _mm512_add_ps (sum , temp );
390+ // extract results: may be there is a better way?
391+ // Store is cheap and loading them should happen from the store buffers, so this may be faster than shuffling and extracting:
392+ // Although its tempting, avoid using vcompress (a high latency instruction)
393+ //_mm512_mask_compressstoreu_ps(ans, 0x8080, sum);
394+ _mm512_storeu_ps (tempdat , sum );
395+ partialSums [codebookBase + i ] = tempdat [0 ];
396+ partialSums [codebookBase + i + 1 ] = tempdat [8 ];
397+ }
398+ for (; i < clusterCount ; i ++ ) {
399+ partialSums [codebookBase + i ] = euclidean_f32 (codebook , i * size , query , queryOffset , size );
400+ }
401+ }
402+ else if (size == 16 ) {
403+ int i = 0 ;
404+ __m512 q = _mm512_loadu_ps (query + queryOffset );
405+ for (; i < clusterCount ; i += 1 ) {
406+ __m512 c1 = _mm512_loadu_ps (codebook + i * size );
407+ __m512 diff = _mm512_sub_ps (c1 , qq );
408+ __m512 sum = _mm512_fmadd_ps (diff , diff , _mm512_setzero_ps ());
409+ partialSums [codebookBase + i ] = _mm512_reduce_add_ps (sum );
410+ }
411+ }
412+ else {
413+ for (int i = 0 ; i < clusterCount ; i ++ ) {
414+ partialSums [codebookBase + i ] = euclidean_f32 (codebook , i * size , query , queryOffset , size );
415+ }
273416 }
274417}
275418
0 commit comments