@@ -228,33 +228,16 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
228228 var encodedChunk = getChunk (node2 );
229229 var encodedOffset = getOffsetInChunk (node2 );
230230 // compute the dot product of the query and the codebook centroids corresponding to the encoded points
231- float dp = 0 ;
232- for (int m = 0 ; m < subspaceCount ; m ++) {
233- int centroidIndex = Byte .toUnsignedInt (encodedChunk .get (m + encodedOffset ));
234- int centroidLength = pq .subvectorSizesAndOffsets [m ][0 ];
235- int centroidOffset = pq .subvectorSizesAndOffsets [m ][1 ];
236- dp += VectorUtil .dotProduct (pq .codebooks [m ], centroidIndex * centroidLength , centeredQuery , centroidOffset , centroidLength );
237- }
231+ float dp = VectorUtil .pqScoreDotProduct (pq .codebooks , pq .subvectorSizesAndOffsets , encodedChunk , encodedOffset , centeredQuery , subspaceCount );
238232 // scale to [0, 1]
239233 return (1 + dp ) / 2 ;
240234 };
241235 case COSINE :
242- float norm1 = VectorUtil .dotProduct (centeredQuery , centeredQuery );
243236 return (node2 ) -> {
244237 var encodedChunk = getChunk (node2 );
245238 var encodedOffset = getOffsetInChunk (node2 );
246- // compute the dot product of the query and the codebook centroids corresponding to the encoded points
247- float sum = 0 ;
248- float norm2 = 0 ;
249- for (int m = 0 ; m < subspaceCount ; m ++) {
250- int centroidIndex = Byte .toUnsignedInt (encodedChunk .get (m + encodedOffset ));
251- int centroidLength = pq .subvectorSizesAndOffsets [m ][0 ];
252- int centroidOffset = pq .subvectorSizesAndOffsets [m ][1 ];
253- var codebookOffset = centroidIndex * centroidLength ;
254- sum += VectorUtil .dotProduct (pq .codebooks [m ], codebookOffset , centeredQuery , centroidOffset , centroidLength );
255- norm2 += VectorUtil .dotProduct (pq .codebooks [m ], codebookOffset , pq .codebooks [m ], codebookOffset , centroidLength );
256- }
257- float cosine = sum / (float ) Math .sqrt (norm1 * norm2 );
239+ // compute the cosine of the query and the codebook centroids corresponding to the encoded points
240+ float cosine = VectorUtil .pqScoreCosine (pq .codebooks , pq .subvectorSizesAndOffsets , encodedChunk , encodedOffset , centeredQuery , subspaceCount );
258241 // scale to [0, 1]
259242 return (1 + cosine ) / 2 ;
260243 };
@@ -263,13 +246,7 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
263246 var encodedChunk = getChunk (node2 );
264247 var encodedOffset = getOffsetInChunk (node2 );
265248 // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
266- float sum = 0 ;
267- for (int m = 0 ; m < subspaceCount ; m ++) {
268- int centroidIndex = Byte .toUnsignedInt (encodedChunk .get (m + encodedOffset ));
269- int centroidLength = pq .subvectorSizesAndOffsets [m ][0 ];
270- int centroidOffset = pq .subvectorSizesAndOffsets [m ][1 ];
271- sum += VectorUtil .squareL2Distance (pq .codebooks [m ], centroidIndex * centroidLength , centeredQuery , centroidOffset , centroidLength );
272- }
249+ float sum = VectorUtil .pqScoreEuclidean (pq .codebooks , pq .subvectorSizesAndOffsets , encodedChunk , encodedOffset , centeredQuery , subspaceCount );
273250 // scale to [0, 1]
274251 return 1 / (1 + sum );
275252 };
@@ -290,40 +267,16 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
290267 var node2Chunk = getChunk (node2 );
291268 var node2Offset = getOffsetInChunk (node2 );
292269 // compute the dot product between the codebook centroids corresponding to the two encoded points (node1 and node2)
293- float dp = 0 ;
294- for (int m = 0 ; m < subspaceCount ; m ++) {
295- int centroidIndex1 = Byte .toUnsignedInt (node1Chunk .get (m + node1Offset ));
296- int centroidIndex2 = Byte .toUnsignedInt (node2Chunk .get (m + node2Offset ));
297- int centroidLength = pq .subvectorSizesAndOffsets [m ][0 ];
298- dp += VectorUtil .dotProduct (pq .codebooks [m ], centroidIndex1 * centroidLength , pq .codebooks [m ], centroidIndex2 * centroidLength , centroidLength );
299- }
270+ float dp = VectorUtil .pqScoreDotProduct (pq .codebooks , pq .subvectorSizesAndOffsets , node1Chunk , node1Offset , node2Chunk , node2Offset , subspaceCount );
300271 // scale to [0, 1]
301272 return (1 + dp ) / 2 ;
302273 };
303274 case COSINE :
304- float norm1 = 0.0f ;
305- for (int m1 = 0 ; m1 < subspaceCount ; m1 ++) {
306- int centroidIndex = Byte .toUnsignedInt (node1Chunk .get (m1 + node1Offset ));
307- int centroidLength = pq .subvectorSizesAndOffsets [m1 ][0 ];
308- var codebookOffset = centroidIndex * centroidLength ;
309- norm1 += VectorUtil .dotProduct (pq .codebooks [m1 ], codebookOffset , pq .codebooks [m1 ], codebookOffset , centroidLength );
310- }
311- final float norm1final = norm1 ;
312275 return (node2 ) -> {
313276 var node2Chunk = getChunk (node2 );
314277 var node2Offset = getOffsetInChunk (node2 );
315278 // compute the cosine between the codebook centroids corresponding to the two encoded points (node1 and node2)
316- float sum = 0 ;
317- float norm2 = 0 ;
318- for (int m = 0 ; m < subspaceCount ; m ++) {
319- int centroidIndex1 = Byte .toUnsignedInt (node1Chunk .get (m + node1Offset ));
320- int centroidIndex2 = Byte .toUnsignedInt (node2Chunk .get (m + node2Offset ));
321- int centroidLength = pq .subvectorSizesAndOffsets [m ][0 ];
322- int codebookOffset = centroidIndex2 * centroidLength ;
323- sum += VectorUtil .dotProduct (pq .codebooks [m ], codebookOffset , pq .codebooks [m ], centroidIndex1 * centroidLength , centroidLength );
324- norm2 += VectorUtil .dotProduct (pq .codebooks [m ], codebookOffset , pq .codebooks [m ], codebookOffset , centroidLength );
325- }
326- float cosine = sum / (float ) Math .sqrt (norm1final * norm2 );
279+ float cosine = VectorUtil .pqScoreCosine (pq .codebooks , pq .subvectorSizesAndOffsets , node1Chunk , node1Offset , node2Chunk , node2Offset , subspaceCount );
327280 // scale to [0, 1]
328281 return (1 + cosine ) / 2 ;
329282 };
@@ -332,13 +285,7 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
332285 var node2Chunk = getChunk (node2 );
333286 var node2Offset = getOffsetInChunk (node2 );
334287 // compute the euclidean distance between the codebook centroids corresponding to the two encoded points (node1 and node2)
335- float sum = 0 ;
336- for (int m = 0 ; m < subspaceCount ; m ++) {
337- int centroidIndex1 = Byte .toUnsignedInt (node1Chunk .get (m + node1Offset ));
338- int centroidIndex2 = Byte .toUnsignedInt (node2Chunk .get (m + node2Offset ));
339- int centroidLength = pq .subvectorSizesAndOffsets [m ][0 ];
340- sum += VectorUtil .squareL2Distance (pq .codebooks [m ], centroidIndex1 * centroidLength , pq .codebooks [m ], centroidIndex2 * centroidLength , centroidLength );
341- }
288+ float sum = VectorUtil .pqScoreEuclidean (pq .codebooks , pq .subvectorSizesAndOffsets , node1Chunk , node1Offset , node2Chunk , node2Offset , subspaceCount );
342289 // scale to [0, 1]
343290 return 1 / (1 + sum );
344291 };
0 commit comments