Skip to content

Commit 5969a42

Browse files
authored
Rebased reduce_lanes changes with latest main (#611)
1 parent 8e85edb commit 5969a42

File tree

6 files changed

+1357
-60
lines changed

6 files changed

+1357
-60
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Copyright DataStax, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.github.jbellis.jvector.bench;
17+
18+
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import io.github.jbellis.jvector.quantization.MutablePQVectors;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.TimeUnit;
39+
40+
/**
41+
* Benchmark that compares the distance calculation of mutable Product Quantized vectors vs full precision vectors.
42+
*/
43+
@BenchmarkMode(Mode.AverageTime)
44+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
45+
@State(Scope.Thread)
46+
@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=false"})
47+
@Warmup(iterations = 2)
48+
@Measurement(iterations = 3)
49+
@Threads(1)
50+
public class PQDistanceCalculationMutableVectorBenchmark {
51+
private static final Logger log = LoggerFactory.getLogger(PQDistanceCalculationMutableVectorBenchmark.class);
52+
private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport();
53+
54+
private List<VectorFloat<?>> vectors;
55+
private PQVectors pqVectors;
56+
private List<VectorFloat<?>> queryVectors;
57+
private ProductQuantization pq;
58+
private BuildScoreProvider buildScoreProvider;
59+
60+
@Param({"1536"})
61+
private int dimension;
62+
63+
@Param({"10000"})
64+
private int vectorCount;
65+
66+
@Param({"100"})
67+
private int queryCount;
68+
69+
@Param({ "16","32", "64","96", "192"})
70+
private int M; // Number of subspaces for PQ
71+
72+
@Param
73+
private VectorSimilarityFunction vsf;
74+
75+
@Setup
76+
public void setup() throws IOException {
77+
log.info("Creating dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount);
78+
79+
// Create random vectors
80+
vectors = new ArrayList<>(vectorCount);
81+
for (int i = 0; i < vectorCount; i++) {
82+
vectors.add(createRandomVector(dimension));
83+
}
84+
85+
// Create query vectors
86+
queryVectors = new ArrayList<>(queryCount);
87+
for (int i = 0; i < queryCount; i++) {
88+
queryVectors.add(createRandomVector(dimension));
89+
}
90+
91+
RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension);
92+
// Create Mutable PQ vectors
93+
pq = ProductQuantization.compute(ravv, M, 256, true);
94+
pqVectors = new MutablePQVectors(pq);
95+
// build the index vector-at-a-time (on disk)
96+
for (int ordinal = 0; ordinal < vectors.size(); ordinal++)
97+
{
98+
VectorFloat<?> v = vectors.get(ordinal);
99+
// compress the new vector and add it to the PQVectors
100+
((MutablePQVectors)pqVectors).encodeAndSet(ordinal, v);
101+
}
102+
buildScoreProvider = BuildScoreProvider.pqBuildScoreProvider(vsf, pqVectors);
103+
log.info("Created dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount);
104+
}
105+
106+
@Benchmark
107+
public void scoreCalculation(Blackhole blackhole) {
108+
float totalSimilarity = 0;
109+
110+
for (VectorFloat<?> query : queryVectors) {
111+
112+
ScoreFunction.ApproximateScoreFunction asf = pqVectors.scoreFunctionFor(query, vsf);
113+
for (int i = 0; i < vectorCount; i++) {
114+
float similarity = asf.similarityTo(i);
115+
totalSimilarity += similarity;
116+
}
117+
}
118+
119+
blackhole.consume(totalSimilarity);
120+
}
121+
122+
@Benchmark
123+
public void diversityCalculation(Blackhole blackhole) {
124+
float totalSimilarity = 0;
125+
126+
for (int q = 0; q < queryCount; q++) {
127+
for (int i = 0; i < vectorCount; i++) {
128+
final ScoreFunction sf = buildScoreProvider.diversityProviderFor(i).scoreFunction();
129+
float similarity = sf.similarityTo(q);
130+
totalSimilarity += similarity;
131+
}
132+
}
133+
134+
blackhole.consume(totalSimilarity);
135+
}
136+
137+
private VectorFloat<?> createRandomVector(int dimension) {
138+
VectorFloat<?> vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension);
139+
for (int i = 0; i < dimension; i++) {
140+
vector.set(i, (float) Math.random());
141+
}
142+
return vector;
143+
}
144+
}

jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java

Lines changed: 7 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -228,33 +228,16 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
228228
var encodedChunk = getChunk(node2);
229229
var encodedOffset = getOffsetInChunk(node2);
230230
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
231-
float dp = 0;
232-
for (int m = 0; m < subspaceCount; m++) {
233-
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
234-
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
235-
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
236-
dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
237-
}
231+
float dp = VectorUtil.pqScoreDotProduct(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
238232
// scale to [0, 1]
239233
return (1 + dp) / 2;
240234
};
241235
case COSINE:
242-
float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery);
243236
return (node2) -> {
244237
var encodedChunk = getChunk(node2);
245238
var encodedOffset = getOffsetInChunk(node2);
246-
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
247-
float sum = 0;
248-
float norm2 = 0;
249-
for (int m = 0; m < subspaceCount; m++) {
250-
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
251-
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
252-
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
253-
var codebookOffset = centroidIndex * centroidLength;
254-
sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength);
255-
norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
256-
}
257-
float cosine = sum / (float) Math.sqrt(norm1 * norm2);
239+
// compute the cosine of the query and the codebook centroids corresponding to the encoded points
240+
float cosine = VectorUtil.pqScoreCosine(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
258241
// scale to [0, 1]
259242
return (1 + cosine) / 2;
260243
};
@@ -263,13 +246,7 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
263246
var encodedChunk = getChunk(node2);
264247
var encodedOffset = getOffsetInChunk(node2);
265248
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
266-
float sum = 0;
267-
for (int m = 0; m < subspaceCount; m++) {
268-
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
269-
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
270-
int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
271-
sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
272-
}
249+
float sum = VectorUtil.pqScoreEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
273250
// scale to [0, 1]
274251
return 1 / (1 + sum);
275252
};
@@ -290,40 +267,16 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
290267
var node2Chunk = getChunk(node2);
291268
var node2Offset = getOffsetInChunk(node2);
292269
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
293-
float dp = 0;
294-
for (int m = 0; m < subspaceCount; m++) {
295-
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
296-
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
297-
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
298-
dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
299-
}
270+
float dp = VectorUtil.pqScoreDotProduct(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
300271
// scale to [0, 1]
301272
return (1 + dp) / 2;
302273
};
303274
case COSINE:
304-
float norm1 = 0.0f;
305-
for (int m1 = 0; m1 < subspaceCount; m1++) {
306-
int centroidIndex = Byte.toUnsignedInt(node1Chunk.get(m1 + node1Offset));
307-
int centroidLength = pq.subvectorSizesAndOffsets[m1][0];
308-
var codebookOffset = centroidIndex * centroidLength;
309-
norm1 += VectorUtil.dotProduct(pq.codebooks[m1], codebookOffset, pq.codebooks[m1], codebookOffset, centroidLength);
310-
}
311-
final float norm1final = norm1;
312275
return (node2) -> {
313276
var node2Chunk = getChunk(node2);
314277
var node2Offset = getOffsetInChunk(node2);
315278
// compute the dot product of the query and the codebook centroids corresponding to the encoded points
316-
float sum = 0;
317-
float norm2 = 0;
318-
for (int m = 0; m < subspaceCount; m++) {
319-
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
320-
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
321-
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
322-
int codebookOffset = centroidIndex2 * centroidLength;
323-
sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], centroidIndex1 * centroidLength, centroidLength);
324-
norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
325-
}
326-
float cosine = sum / (float) Math.sqrt(norm1final * norm2);
279+
float cosine = VectorUtil.pqScoreCosine(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
327280
// scale to [0, 1]
328281
return (1 + cosine) / 2;
329282
};
@@ -332,13 +285,7 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve
332285
var node2Chunk = getChunk(node2);
333286
var node2Offset = getOffsetInChunk(node2);
334287
// compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points
335-
float sum = 0;
336-
for (int m = 0; m < subspaceCount; m++) {
337-
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
338-
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
339-
int centroidLength = pq.subvectorSizesAndOffsets[m][0];
340-
sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength);
341-
}
288+
float sum = VectorUtil.pqScoreEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
342289
// scale to [0, 1]
343290
return 1 / (1 + sum);
344291
};

jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,4 +547,90 @@ public float nvqUniformLoss(VectorFloat<?> vector, float minValue, float maxValu
547547
return squaredSum;
548548
}
549549

550+
@Override
551+
public float pqScoreDotProduct(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
552+
float dp = 0;
553+
for (int m = 0; m < subspaceCount; m++) {
554+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
555+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
556+
int centroidLength = subvectorSizesAndOffsets[m][0];
557+
dp += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
558+
}
559+
return dp;
560+
}
561+
562+
563+
@Override
564+
public float pqScoreCosine(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
565+
float sum = 0;
566+
float aMagnitude = 0;
567+
float bMagnitude = 0;
568+
for (int m = 0; m < subspaceCount; m++) {
569+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
570+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
571+
int centroidLength = subvectorSizesAndOffsets[m][0];
572+
sum += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
573+
aMagnitude += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex1 * centroidLength, centroidLength);
574+
bMagnitude += dotProduct(codebooks[m], centroidIndex2 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
575+
}
576+
return (float)(sum / Math.sqrt(aMagnitude * bMagnitude));
577+
}
578+
579+
@Override
580+
public float pqScoreEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> node1Chunk, int node1Offset, ByteSequence<?> node2Chunk, int node2Offset, int subspaceCount) {
581+
float sum = 0;
582+
for (int m = 0; m < subspaceCount; m++) {
583+
int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
584+
int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
585+
int centroidLength = subvectorSizesAndOffsets[m][0];
586+
587+
sum += squareDistance(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength);
588+
}
589+
return sum;
590+
591+
}
592+
593+
@Override
594+
public float pqScoreDotProduct(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> encodedChunk, int encodedOffset, VectorFloat<?> centeredQuery, int subspaceCount) {
595+
float dp = 0;
596+
for (int m = 0; m < subspaceCount; m++) {
597+
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
598+
int centroidLength = subvectorSizesAndOffsets[m][0];
599+
int centroidOffset = subvectorSizesAndOffsets[m][1];
600+
dp += dotProduct(codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
601+
}
602+
return dp;
603+
}
604+
605+
@Override
606+
public float pqScoreCosine(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> encodedChunk, int encodedOffset, VectorFloat<?> centeredQuery, int subspaceCount) {
607+
float sum = 0;
608+
float aMagnitude = 0;
609+
float bMagnitude = 0;
610+
611+
for (int m = 0; m < subspaceCount; m++) {
612+
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
613+
int centroidLength = subvectorSizesAndOffsets[m][0];
614+
int centroidOffset = subvectorSizesAndOffsets[m][1];
615+
var codebookOffset = centroidIndex * centroidLength;
616+
sum += dotProduct(codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength);
617+
aMagnitude += dotProduct(codebooks[m], codebookOffset, codebooks[m], codebookOffset, centroidLength);
618+
bMagnitude += dotProduct(centeredQuery, centroidOffset, centeredQuery, centroidOffset, centroidLength);
619+
}
620+
float cosine = sum / (float) Math.sqrt(aMagnitude * bMagnitude);
621+
return cosine;
622+
}
623+
624+
@Override
625+
public float pqScoreEuclidean(VectorFloat<?>[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence<?> encodedChunk, int encodedOffset, VectorFloat<?> centeredQuery, int subspaceCount) {
626+
float sum = 0;
627+
for (int m = 0; m < subspaceCount; m++) {
628+
int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset));
629+
int centroidLength = subvectorSizesAndOffsets[m][0];
630+
int centroidOffset = subvectorSizesAndOffsets[m][1];
631+
sum += squareDistance(codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
632+
}
633+
return sum;
634+
}
635+
550636
}

0 commit comments

Comments
 (0)