Wire calculatePartialSums to native SIMD via Panama FFI downcall #651
Conversation
Force-pushed: 70cd2fb to de4ff79
jshook left a comment:
I would like to see much more coverage of these with numerical tests. Are there some already which aren't seen here?
ashkrisk left a comment:
Looks like an excellent set of optimizations. Left a few comments.
+1 to @jshook's comment about numerical tests. This PR touches almost every single function in the native supporting library, and it would be good to have a set of tests accompanying it, perhaps also in C.
  if [ "$(printf '%s\n' "$MIN_GCC_VERSION" "$CURRENT_GCC_VERSION" | sort -V | head -n1)" = "$MIN_GCC_VERSION" ]; then
      rm -rf ../resources/libjvector.so
-     gcc -fPIC -O3 -march=icelake-server -c jvector_simd.c -o jvector_simd.o
+     gcc -fPIC -O3 -march=skylake-avx512 -c jvector_simd.c -o jvector_simd.o
Is there a strong reason to lower the target micro-architecture version?
This change addresses an actual bug in our build configuration. We currently compile targeting icelake-server, but at runtime we only check for skylake-avx512. This mismatch allows the compiler to emit Ice Lake–specific instructions that may be executed on a Skylake CPU, which can result in a SIGILL.
Targeting skylake-avx512 resolves this issue and is sufficient for the kernels we currently have; there’s no requirement for Ice Lake–specific features here.
  case 0:
      calculate_partial_sums_euclidean_f32_512(codebook, codebookIndex, size, clusterCount, query, queryOffset, partialSums);
      break;
  case 1:
Can we use public enums here? Jextract should automatically make the enums available to the Java code as constants. Alternatively we could skip the parameter-based dispatch altogether and simply expose both versions of the function to Java code.
> Alternatively we could skip the parameter-based dispatch altogether and simply expose both versions of the function to Java code.
Agree. I have updated to use this approach.
  __m512 vaMagnitude = _mm512_setzero_ps();
  int i = 0;
  int limit = baseOffsetsLength - (baseOffsetsLength % 16);
  const __m512i initialIndexRegister = _mm512_setr_epi32(-16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1);
It's good that this isn't a global variable anymore, but given that it's used in multiple places, does it make sense to have it as a global constant?
I generally try to avoid global variables and prefer function‑local const values to keep dependencies explicit and contained, unless avoiding globals would cause significant duplication. In this case, it’s only used in two places, so the duplication is minimal.
Looks like a lot of functions that are no longer in the public header are still declared here. Should fix itself on re-running jextract.
Good catch — these should indeed be removed. It looks like NativeSimdOps.java is checked in with signatures that are supposed to mirror the public header. That also explains why I keep seeing a locally modified file every time I build this branch.
@jshook @ashkrisk Beyond that, there are only two additional native functions exposed to Java via FFI—
I think this is covered by my earlier comment, but to clarify: only three native functions are exposed to Java, and this patch modifies just one of them. That function already has strong numerical test coverage, which was actually helpful in catching bugs in an earlier version of the code.
* Replace icelake-server gcc target with skylake-avx512 in build script
* Remove global mutable state: eliminate initialIndexRegister, indexIncrement, maskSeventhBit, maskEighthBit globals and their constructor initializer; move mask constants (maskSeventhBit, maskEighthBit) to local scope inside lookup_partial_sums
* Add shared reduce_add_128_ps and reduce_add_256_ps helper functions using proper horizontal-add sequences instead of store-to-array loops
* Remove redundant if (length >= N) guards in all SIMD kernels — the loop body already handles the zero-iteration case correctly
* Replace store-to-aligned-array horizontal reduction pattern with the new helpers across all 128- and 256-bit dot product and euclidean distance functions
* Remove preferred_size parameter from dot_product_f32 and euclidean_f32; always dispatch to AVX-512 when length >= 16
* Standardize inline annotations: replace __attribute__((always_inline)) inline with JV_FINLINE / JV_INLINE macros throughout
…sizes 4, 8 & 16 on AVX-512

Add SIMD fast paths in calculate_partial_sums_dot_f32_512 and calculate_partial_sums_euclidean_f32_512 for the most common PQ subvector sizes:

* size == 4: broadcast a 128-bit query fragment across all four 128-bit lanes of a ZMM register, load four consecutive centroids at once, and reduce each lane independently using two shuffle+add pairs. Produces 4 partial sums per loop iteration instead of 1.
* size == 8: broadcast a 256-bit query fragment across both 256-bit halves of a ZMM register, load two consecutive centroids at once, and reduce across 128-bit lanes followed by within-lane shuffles. Produces 2 partial sums per loop iteration instead of 1.
* size == 16: the query and each centroid fill a ZMM register; load the query into a ZMM register once, then loop over the centroids. Produces one partial sum per loop iteration, but avoids reloading the query on every iteration.

All three paths fall back to the default way of computing dot_product_f32 / euclidean_f32 in a loop for any tail elements or unsupported sizes.
Force-pushed: 159c21f to 0276bf1
This change uses a native implementation of calculatePartialSums to accelerate PQ query scoring. On ada002-100k with FUSED_PQ (num PQ subspaces / M = 96, JDK build 23.0.1+11-39), it delivers 2–3× higher QPS and 40–65% lower mean latency across common overquery settings. Index build time, disk usage, and heap usage show no meaningful regression. The optimization is isolated to the PQ path; non-PQ queries are unaffected.

Combined QPS and Latency Results (FUSED_PQ)
topK = 10
topK = 100
Summary of changes in this PR:
* Wire calculatePartialSums in NativeVectorUtilSupport to a new Panama FFI downcall for the native calculate_partial_sums_f32_512 SIMD implementation.