Skip to content

Commit 165d46c

Browse files
authored
Merge branch 'main' into swiss-table-proto
2 parents 62c95c2 + ebc04a6 commit 165d46c

30 files changed

Lines changed: 249 additions & 493 deletions

File tree

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerFloat32OperationBenchmark.java

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public class VectorScorerFloat32OperationBenchmark {
6464
@Param({ "1", "128", "207", "256", "300", "512", "702", "1024", "1536", "2048" })
6565
public int size;
6666

67-
@Param({ "COSINE", "DOT_PRODUCT", "EUCLIDEAN" })
67+
@Param({ "DOT_PRODUCT", "EUCLIDEAN" })
6868
public VectorSimilarityType function;
6969

7070
@FunctionalInterface
@@ -101,13 +101,11 @@ public void init() {
101101
MemorySegment.copy(MemorySegment.ofArray(floatsB), LAYOUT_LE_FLOAT, 0L, nativeSegB, LAYOUT_LE_FLOAT, 0L, floatsB.length);
102102

103103
luceneImpl = switch (function) {
104-
case COSINE -> VectorUtil::cosine;
105104
case DOT_PRODUCT -> VectorUtil::dotProduct;
106105
case EUCLIDEAN -> VectorUtil::squareDistance;
107106
default -> throw new UnsupportedOperationException("Not used");
108107
};
109108
nativeImpl = switch (function) {
110-
case COSINE -> VectorScorerFloat32OperationBenchmark::cosineFloat32;
111109
case DOT_PRODUCT -> VectorScorerFloat32OperationBenchmark::dotProductFloat32;
112110
case EUCLIDEAN -> VectorScorerFloat32OperationBenchmark::squareDistanceFloat32;
113111
default -> throw new UnsupportedOperationException("Not used");
@@ -147,14 +145,6 @@ static VectorSimilarityFunctions vectorSimilarityFunctions() {
147145
return NativeAccess.instance().getVectorSimilarityFunctions().get();
148146
}
149147

150-
static float cosineFloat32(MemorySegment a, MemorySegment b, int length) {
151-
try {
152-
return (float) vectorSimilarityFunctions.cosineHandleFloat32().invokeExact(a, b, length);
153-
} catch (Throwable e) {
154-
throw rethrow(e);
155-
}
156-
}
157-
158148
static float dotProductFloat32(MemorySegment a, MemorySegment b, int length) {
159149
try {
160150
return (float) vectorSimilarityFunctions.dotProductHandleFloat32().invokeExact(a, b, length);

benchmarks/src/test/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerFloat32OperationBenchmarkTests.java

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,9 @@ public void test() {
4949
bench.init();
5050
try {
5151
float expected = switch (function) {
52-
case COSINE -> cosineFloat32Scalar(bench.floatsA, bench.floatsB);
5352
case DOT_PRODUCT -> dotProductFloat32Scalar(bench.floatsA, bench.floatsB);
5453
case EUCLIDEAN -> squareDistanceFloat32Scalar(bench.floatsA, bench.floatsB);
55-
case MAXIMUM_INNER_PRODUCT -> throw new AssumptionViolatedException("Not tested");
54+
default -> throw new AssumptionViolatedException("Not tested");
5655
};
5756
assertEquals(expected, bench.lucene(), delta);
5857
assertEquals(expected, bench.luceneWithCopy(), delta);
@@ -81,20 +80,6 @@ public static Iterable<Object[]> parametersFactory() {
8180
}
8281
}
8382

84-
/** Computes the cosine of the given vectors a and b. */
85-
static float cosineFloat32Scalar(float[] a, float[] b) {
86-
float dot = 0, normA = 0, normB = 0;
87-
for (int i = 0; i < a.length; i++) {
88-
dot += a[i] * b[i];
89-
normA += a[i] * a[i];
90-
normB += b[i] * b[i];
91-
}
92-
double normAA = Math.sqrt(normA);
93-
double normBB = Math.sqrt(normB);
94-
if (normAA == 0.0f || normBB == 0.0f) return 0.0f;
95-
return (float) (dot / (normAA * normBB));
96-
}
97-
9883
/** Computes the dot product of the given vectors a and b. */
9984
static float dotProductFloat32Scalar(float[] a, float[] b) {
10085
float res = 0;

docs/changelog/139461.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 139461
2+
summary: Fix checks to define if a JOIN is remote
3+
area: ES|QL
4+
type: bug
5+
issues: []

libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,6 @@ public interface VectorSimilarityFunctions {
127127
*/
128128
MethodHandle squareDistanceHandle7uBulkWithOffsets();
129129

130-
/**
131-
* Produces a method handle returning the cosine of float32 vectors.
132-
*
133-
* <p> The type of the method handle will have {@code float} as return type, The type of
134-
* its first and second arguments will be {@code MemorySegment}, whose contents is the
135-
* vector data floats. The third argument is the length of the vector data - number of
136-
* 4-byte float32 elements.
137-
*/
138-
MethodHandle cosineHandleFloat32();
139-
140130
/**
141131
* Produces a method handle returning the dot product of float32 vectors.
142132
*

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 22 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ public final class JdkVectorLibrary implements VectorLibrary {
3737
static final MethodHandle sqr7u$mh;
3838
static final MethodHandle sqr7uBulk$mh;
3939
static final MethodHandle sqr7uBulkWithOffsets$mh;
40-
static final MethodHandle cosf32$mh;
4140
static final MethodHandle dotf32$mh;
4241
static final MethodHandle sqrf32$mh;
4342

@@ -51,99 +50,28 @@ public final class JdkVectorLibrary implements VectorLibrary {
5150
int caps = (int) vecCaps$mh.invokeExact();
5251
logger.info("vec_caps=" + caps);
5352
if (caps > 0) {
54-
if (caps == 2) {
55-
dot7u$mh = downcallHandle(
56-
"vec_dot7u_2",
57-
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
58-
LinkerHelperUtil.critical()
59-
);
60-
dot7uBulk$mh = downcallHandle(
61-
"vec_dot7u_bulk_2",
62-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
63-
LinkerHelperUtil.critical()
64-
);
65-
dot7uBulkWithOffsets$mh = downcallHandle(
66-
"vec_dot7u_bulk_offsets_2",
67-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
68-
LinkerHelperUtil.critical()
69-
);
70-
sqr7u$mh = downcallHandle(
71-
"vec_sqr7u_2",
72-
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
73-
LinkerHelperUtil.critical()
74-
);
75-
sqr7uBulk$mh = downcallHandle(
76-
"vec_sqr7u_bulk_2",
77-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
78-
LinkerHelperUtil.critical()
79-
);
80-
sqr7uBulkWithOffsets$mh = downcallHandle(
81-
"vec_sqr7u_bulk_offsets_2",
82-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
83-
LinkerHelperUtil.critical()
84-
);
85-
cosf32$mh = downcallHandle(
86-
"vec_cosf32_2",
87-
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
88-
LinkerHelperUtil.critical()
89-
);
90-
dotf32$mh = downcallHandle(
91-
"vec_dotf32_2",
92-
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
93-
LinkerHelperUtil.critical()
94-
);
95-
sqrf32$mh = downcallHandle(
96-
"vec_sqrf32_2",
97-
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
98-
LinkerHelperUtil.critical()
99-
);
100-
} else {
101-
dot7u$mh = downcallHandle(
102-
"vec_dot7u",
103-
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
104-
LinkerHelperUtil.critical()
105-
);
106-
dot7uBulk$mh = downcallHandle(
107-
"vec_dot7u_bulk",
108-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
109-
LinkerHelperUtil.critical()
110-
);
111-
dot7uBulkWithOffsets$mh = downcallHandle(
112-
"vec_dot7u_bulk_offsets",
113-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
114-
LinkerHelperUtil.critical()
115-
);
116-
sqr7u$mh = downcallHandle(
117-
"vec_sqr7u",
118-
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
119-
LinkerHelperUtil.critical()
120-
);
121-
sqr7uBulk$mh = downcallHandle(
122-
"vec_sqr7u_bulk",
123-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
124-
LinkerHelperUtil.critical()
125-
);
126-
sqr7uBulkWithOffsets$mh = downcallHandle(
127-
"vec_sqr7u_bulk_offsets",
128-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
129-
LinkerHelperUtil.critical()
130-
);
131-
cosf32$mh = downcallHandle(
132-
"vec_cosf32",
133-
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
134-
LinkerHelperUtil.critical()
135-
);
136-
dotf32$mh = downcallHandle(
137-
"vec_dotf32",
138-
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
139-
LinkerHelperUtil.critical()
140-
);
141-
sqrf32$mh = downcallHandle(
142-
"vec_sqrf32",
143-
FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT),
144-
LinkerHelperUtil.critical()
145-
);
146-
}
53+
String suffix = caps == 2 ? "_2" : "";
54+
FunctionDescriptor intSingle = FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT);
55+
FunctionDescriptor floatSingle = FunctionDescriptor.of(JAVA_FLOAT, ADDRESS, ADDRESS, JAVA_INT);
56+
FunctionDescriptor bulk = FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS);
57+
FunctionDescriptor bulkOffsets = FunctionDescriptor.ofVoid(
58+
ADDRESS,
59+
ADDRESS,
60+
JAVA_INT,
61+
JAVA_INT,
62+
ADDRESS,
63+
JAVA_INT,
64+
ADDRESS
65+
);
66+
67+
dot7u$mh = downcallHandle("vec_dot7u" + suffix, intSingle, LinkerHelperUtil.critical());
68+
dot7uBulk$mh = downcallHandle("vec_dot7u_bulk" + suffix, bulk, LinkerHelperUtil.critical());
69+
dot7uBulkWithOffsets$mh = downcallHandle("vec_dot7u_bulk_offsets" + suffix, bulkOffsets, LinkerHelperUtil.critical());
70+
sqr7u$mh = downcallHandle("vec_sqr7u" + suffix, intSingle, LinkerHelperUtil.critical());
71+
sqr7uBulk$mh = downcallHandle("vec_sqr7u_bulk" + suffix, bulk, LinkerHelperUtil.critical());
72+
sqr7uBulkWithOffsets$mh = downcallHandle("vec_sqr7u_bulk_offsets" + suffix, bulkOffsets, LinkerHelperUtil.critical());
73+
dotf32$mh = downcallHandle("vec_dotf32" + suffix, floatSingle, LinkerHelperUtil.critical());
74+
sqrf32$mh = downcallHandle("vec_sqrf32" + suffix, floatSingle, LinkerHelperUtil.critical());
14775
INSTANCE = new JdkVectorSimilarityFunctions();
14876
} else {
14977
if (caps < 0) {
@@ -157,7 +85,6 @@ public final class JdkVectorLibrary implements VectorLibrary {
15785
sqr7u$mh = null;
15886
sqr7uBulk$mh = null;
15987
sqr7uBulkWithOffsets$mh = null;
160-
cosf32$mh = null;
16188
dotf32$mh = null;
16289
sqrf32$mh = null;
16390
INSTANCE = null;
@@ -243,19 +170,6 @@ static void squareDistance7uBulkWithOffsets(
243170
sqr7uBulkWithOffsets(a, b, length, pitch, offsets, count, result);
244171
}
245172

246-
/**
247-
* Computes the cosine of given float32 vectors.
248-
*
249-
* @param a address of the first vector
250-
* @param b address of the second vector
251-
* @param elementCount the vector dimensions, number of float32 elements in the segment
252-
*/
253-
static float cosineF32(MemorySegment a, MemorySegment b, int elementCount) {
254-
checkByteSize(a, b);
255-
Objects.checkFromIndexSize(0, elementCount, (int) a.byteSize() / Float.BYTES);
256-
return cosf32(a, b, elementCount);
257-
}
258-
259173
/**
260174
* Computes the dot product of given float32 vectors.
261175
*
@@ -352,14 +266,6 @@ private static void sqr7uBulkWithOffsets(
352266
}
353267
}
354268

355-
private static float cosf32(MemorySegment a, MemorySegment b, int length) {
356-
try {
357-
return (float) JdkVectorLibrary.cosf32$mh.invokeExact(a, b, length);
358-
} catch (Throwable t) {
359-
throw new AssertionError(t);
360-
}
361-
}
362-
363269
private static float dotf32(MemorySegment a, MemorySegment b, int length) {
364270
try {
365271
return (float) JdkVectorLibrary.dotf32$mh.invokeExact(a, b, length);
@@ -382,7 +288,6 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
382288
static final MethodHandle SQR_HANDLE_7U;
383289
static final MethodHandle SQR_HANDLE_7U_BULK;
384290
static final MethodHandle SQR_HANDLE_7U_BULK_WITH_OFFSETS;
385-
static final MethodHandle COS_HANDLE_FLOAT32;
386291
static final MethodHandle DOT_HANDLE_FLOAT32;
387292
static final MethodHandle SQR_HANDLE_FLOAT32;
388293

@@ -427,7 +332,6 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
427332
);
428333

429334
MethodType singleFloatScorer = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
430-
COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", singleFloatScorer);
431335
DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", singleFloatScorer);
432336
SQR_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistanceF32", singleFloatScorer);
433337
} catch (NoSuchMethodException | IllegalAccessException e) {
@@ -465,11 +369,6 @@ public MethodHandle squareDistanceHandle7uBulkWithOffsets() {
465369
return SQR_HANDLE_7U_BULK_WITH_OFFSETS;
466370
}
467371

468-
@Override
469-
public MethodHandle cosineHandleFloat32() {
470-
return COS_HANDLE_FLOAT32;
471-
}
472-
473372
@Override
474373
public MethodHandle dotProductHandleFloat32() {
475374
return DOT_HANDLE_FLOAT32;

libs/native/src/test/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctionsTests.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
package org.elasticsearch.nativeaccess;
1111

12+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
13+
1214
import org.elasticsearch.common.logging.LogConfigurator;
1315
import org.elasticsearch.common.logging.NodeNamePatternConverter;
1416
import org.elasticsearch.test.ESTestCase;
@@ -34,7 +36,6 @@ public abstract class VectorSimilarityFunctionsTests extends ESTestCase {
3436
public static final Class<IndexOutOfBoundsException> IOOBE = IndexOutOfBoundsException.class;
3537

3638
public enum SimilarityFunction {
37-
COSINE,
3839
DOT_PRODUCT,
3940
SQUARE_DISTANCE
4041
}
@@ -45,10 +46,14 @@ public enum SimilarityFunction {
4546
protected final int size;
4647
protected final Optional<VectorSimilarityFunctions> vectorSimilarityFunctions;
4748

48-
protected static Stream<Object[]> allParameters() {
49+
@ParametersFactory
50+
public static Iterable<Object[]> parametersFactory() {
4951
var dims1 = Arrays.stream(new int[] { 1, 2, 4, 6, 8, 12, 13, 16, 25, 31, 32, 33, 64, 100, 128, 207, 256, 300, 512, 702, 768 });
5052
var dims2 = Arrays.stream(new int[] { 1000, 1023, 1024, 1025, 2047, 2048, 2049, 4095, 4096, 4097 });
51-
return IntStream.concat(dims1, dims2).boxed().flatMap(i -> Stream.of(SimilarityFunction.values()).map(f -> new Object[] { f, i }));
53+
return () -> IntStream.concat(dims1, dims2)
54+
.boxed()
55+
.flatMap(i -> Stream.of(SimilarityFunction.values()).map(f -> new Object[] { f, i }))
56+
.iterator();
5257
}
5358

5459
protected VectorSimilarityFunctionsTests(SimilarityFunction function, int size) {

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryFloat32Tests.java

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
package org.elasticsearch.nativeaccess.jdk;
1111

12-
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
13-
1412
import org.elasticsearch.nativeaccess.VectorSimilarityFunctionsTests;
1513
import org.junit.AfterClass;
1614
import org.junit.BeforeClass;
@@ -44,11 +42,6 @@ public static void afterClass() {
4442
VectorSimilarityFunctionsTests.cleanup();
4543
}
4644

47-
@ParametersFactory
48-
public static Iterable<Object[]> parametersFactory() {
49-
return () -> VectorSimilarityFunctionsTests.allParameters().iterator();
50-
}
51-
5245
public void testAllZeroValues() {
5346
testFloat32Impl(float[]::new);
5447
}
@@ -113,7 +106,6 @@ static float[] randomFloatArray(int length) {
113106
float similarity(MemorySegment a, MemorySegment b, int length) {
114107
try {
115108
return switch (function) {
116-
case COSINE -> (float) getVectorDistance().cosineHandleFloat32().invokeExact(a, b, length);
117109
case DOT_PRODUCT -> (float) getVectorDistance().dotProductHandleFloat32().invokeExact(a, b, length);
118110
case SQUARE_DISTANCE -> (float) getVectorDistance().squareDistanceHandleFloat32().invokeExact(a, b, length);
119111
};
@@ -124,28 +116,11 @@ float similarity(MemorySegment a, MemorySegment b, int length) {
124116

125117
float scalarSimilarity(float[] a, float[] b) {
126118
return switch (function) {
127-
case COSINE -> cosineFloat32Scalar(a, b);
128119
case DOT_PRODUCT -> dotProductFloat32Scalar(a, b);
129120
case SQUARE_DISTANCE -> squareDistanceFloat32Scalar(a, b);
130121
};
131122
}
132123

133-
/** Computes the cosine of the given vectors a and b. */
134-
static float cosineFloat32Scalar(float[] a, float[] b) {
135-
float dot = 0, normA = 0, normB = 0;
136-
for (int i = 0; i < a.length; i++) {
137-
dot += a[i] * b[i];
138-
normA += a[i] * a[i];
139-
normB += b[i] * b[i];
140-
}
141-
double normAA = Math.sqrt(normA);
142-
double normBB = Math.sqrt(normB);
143-
if (normAA == 0.0f || normBB == 0.0f) {
144-
return 0.0f;
145-
}
146-
return (float) (dot / (normAA * normBB));
147-
}
148-
149124
/** Computes the dot product of the given vectors a and b. */
150125
static float dotProductFloat32Scalar(float[] a, float[] b) {
151126
float res = 0;

0 commit comments

Comments
 (0)