Skip to content

Commit 7942f3e

Browse files
authored
Fix dim validation for bit element_type (#114533)
A silly bug has reared its ugly head. Apparently, our dimension validations are predicated on JSON parsing order, which is not good. So, this commit adjusts the dim validations so that they are actual validations, instead of something that occurs during parsing. Additionally, I found that our custom formats were not overriding `getMaxDimensions` correctly. Typically, and in production, this isn't that big of a deal, but I have found it useful to do this for other testing purposes (so that we don't have to rely on the per-field codec for more direct and advanced testing).
1 parent 4f4b91d commit 7942f3e

10 files changed

Lines changed: 75 additions & 38 deletions

File tree

docs/changelog/114533.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 114533
2+
summary: Fix dim validation for bit `element_type`
3+
area: Vector Search
4+
type: bug
5+
issues: []

server/src/main/java/org/elasticsearch/index/codec/vectors/ES813FlatVectorFormat.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
import java.io.IOException;
3434

35+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
36+
3537
public class ES813FlatVectorFormat extends KnnVectorsFormat {
3638

3739
static final String NAME = "ES813FlatVectorFormat";
@@ -55,6 +57,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
5557
return new ES813FlatVectorReader(format.fieldsReader(state));
5658
}
5759

60+
@Override
61+
public int getMaxDimensions(String fieldName) {
62+
return MAX_DIMS_COUNT;
63+
}
64+
5865
static class ES813FlatVectorWriter extends KnnVectorsWriter {
5966

6067
private final FlatVectorsWriter writer;

server/src/main/java/org/elasticsearch/index/codec/vectors/ES813Int8FlatVectorFormat.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131
import java.io.IOException;
3232

33+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
34+
3335
public class ES813Int8FlatVectorFormat extends KnnVectorsFormat {
3436

3537
static final String NAME = "ES813Int8FlatVectorFormat";
@@ -58,6 +60,11 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
5860
return new ES813FlatVectorReader(format.fieldsReader(state));
5961
}
6062

63+
@Override
64+
public int getMaxDimensions(String fieldName) {
65+
return MAX_DIMS_COUNT;
66+
}
67+
6168
@Override
6269
public String toString() {
6370
return NAME + "(name=" + NAME + ", innerFormat=" + format + ")";

server/src/main/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
2424
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
25+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
2526

2627
public final class ES814HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
2728

@@ -70,7 +71,7 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
7071

7172
@Override
7273
public int getMaxDimensions(String fieldName) {
73-
return 1024;
74+
return MAX_DIMS_COUNT;
7475
}
7576

7677
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorFormat.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
import java.io.IOException;
2020

21+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
22+
2123
public class ES815BitFlatVectorFormat extends KnnVectorsFormat {
2224

2325
static final String NAME = "ES815BitFlatVectorFormat";
@@ -45,4 +47,9 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
4547
public String toString() {
4648
return NAME;
4749
}
50+
51+
@Override
52+
public int getMaxDimensions(String fieldName) {
53+
return MAX_DIMS_COUNT;
54+
}
4855
}

server/src/main/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormat.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import java.io.IOException;
2222

23+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
24+
2325
public class ES815HnswBitVectorsFormat extends KnnVectorsFormat {
2426

2527
static final String NAME = "ES815HnswBitVectorsFormat";
@@ -72,4 +74,9 @@ public String toString() {
7274
+ flatVectorsFormat
7375
+ ")";
7476
}
77+
78+
@Override
79+
public int getMaxDimensions(String fieldName) {
80+
return MAX_DIMS_COUNT;
81+
}
7582
}

server/src/main/java/org/elasticsearch/index/codec/vectors/ES816BinaryQuantizedVectorsFormat.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
import java.io.IOException;
3131

32+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
33+
3234
/**
3335
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
3436
*/
@@ -68,6 +70,11 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException
6870
return new ES816BinaryQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), scorer);
6971
}
7072

73+
@Override
74+
public int getMaxDimensions(String fieldName) {
75+
return MAX_DIMS_COUNT;
76+
}
77+
7178
@Override
7279
public String toString() {
7380
return "ES816BinaryQuantizedVectorsFormat(name=" + NAME + ", flatVectorScorer=" + scorer + ")";

server/src/main/java/org/elasticsearch/index/codec/vectors/ES816HnswBinaryQuantizedVectorsFormat.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
4040
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
4141
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
42+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
4243

4344
/**
4445
* Copied from Lucene, replace with Lucene's implementation sometime after Lucene 10
@@ -128,7 +129,7 @@ public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException
128129

129130
@Override
130131
public int getMaxDimensions(String fieldName) {
131-
return 1024;
132+
return MAX_DIMS_COUNT;
132133
}
133134

134135
@Override

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -139,32 +139,27 @@ public static class Builder extends FieldMapper.Builder {
139139
if (o instanceof Integer == false) {
140140
throw new MapperParsingException("Property [dims] on field [" + n + "] must be an integer but got [" + o + "]");
141141
}
142-
int dims = XContentMapValues.nodeIntegerValue(o);
143-
int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT;
144-
int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1;
145-
if (dims < minDims || dims > maxDims) {
146-
throw new MapperParsingException(
147-
"The number of dimensions for field ["
148-
+ n
149-
+ "] should be in the range ["
150-
+ minDims
151-
+ ", "
152-
+ maxDims
153-
+ "] but was ["
154-
+ dims
155-
+ "]"
156-
);
157-
}
158-
if (elementType.getValue() == ElementType.BIT) {
159-
if (dims % Byte.SIZE != 0) {
142+
143+
return XContentMapValues.nodeIntegerValue(o);
144+
}, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null)
145+
.setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current))
146+
.addValidator(dims -> {
147+
if (dims == null) {
148+
return;
149+
}
150+
int maxDims = elementType.getValue() == ElementType.BIT ? MAX_DIMS_COUNT_BIT : MAX_DIMS_COUNT;
151+
int minDims = elementType.getValue() == ElementType.BIT ? Byte.SIZE : 1;
152+
if (dims < minDims || dims > maxDims) {
160153
throw new MapperParsingException(
161-
"The number of dimensions for field [" + n + "] should be a multiple of 8 but was [" + dims + "]"
154+
"The number of dimensions should be in the range [" + minDims + ", " + maxDims + "] but was [" + dims + "]"
162155
);
163156
}
164-
}
165-
return dims;
166-
}, m -> toType(m).fieldType().dims, XContentBuilder::field, Object::toString).setSerializerCheck((id, ic, v) -> v != null)
167-
.setMergeValidator((previous, current, c) -> previous == null || Objects.equals(previous, current));
157+
if (elementType.getValue() == ElementType.BIT) {
158+
if (dims % Byte.SIZE != 0) {
159+
throw new MapperParsingException("The number of dimensions for should be a multiple of 8 but was [" + dims + "]");
160+
}
161+
}
162+
});
168163
private final Parameter<VectorSimilarity> similarity;
169164

170165
private final Parameter<IndexOptions> indexOptions;

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException {
175175
),
176176
fieldMapping(
177177
b -> b.field("type", "dense_vector")
178-
.field("dims", dims)
178+
.field("dims", dims * 8)
179179
.field("index", true)
180180
.field("similarity", "l2_norm")
181181
.field("element_type", "bit")
@@ -192,7 +192,7 @@ protected void registerParameters(ParameterChecker checker) throws IOException {
192192
),
193193
fieldMapping(
194194
b -> b.field("type", "dense_vector")
195-
.field("dims", dims)
195+
.field("dims", dims * 8)
196196
.field("index", true)
197197
.field("similarity", "l2_norm")
198198
.field("element_type", "bit")
@@ -891,9 +891,7 @@ public void testDims() {
891891
})));
892892
assertThat(
893893
e.getMessage(),
894-
equalTo(
895-
"Failed to parse mapping: " + "The number of dimensions for field [field] should be in the range [1, 4096] but was [0]"
896-
)
894+
equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [0]")
897895
);
898896
}
899897
// test max limit for non-indexed vectors
@@ -904,10 +902,7 @@ public void testDims() {
904902
})));
905903
assertThat(
906904
e.getMessage(),
907-
equalTo(
908-
"Failed to parse mapping: "
909-
+ "The number of dimensions for field [field] should be in the range [1, 4096] but was [5000]"
910-
)
905+
equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [5000]")
911906
);
912907
}
913908
// test max limit for indexed vectors
@@ -919,10 +914,7 @@ public void testDims() {
919914
})));
920915
assertThat(
921916
e.getMessage(),
922-
equalTo(
923-
"Failed to parse mapping: "
924-
+ "The number of dimensions for field [field] should be in the range [1, 4096] but was [5000]"
925-
)
917+
equalTo("Failed to parse mapping: " + "The number of dimensions should be in the range [1, 4096] but was [5000]")
926918
);
927919
}
928920
}
@@ -955,6 +947,14 @@ public void testMergeDims() throws IOException {
955947
);
956948
}
957949

950+
public void testLargeDimsBit() throws IOException {
951+
createMapperService(fieldMapping(b -> {
952+
b.field("type", "dense_vector");
953+
b.field("dims", 1024 * Byte.SIZE);
954+
b.field("element_type", ElementType.BIT.toString());
955+
}));
956+
}
957+
958958
public void testDefaults() throws Exception {
959959
DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "dense_vector").field("dims", 3)));
960960

0 commit comments

Comments
 (0)