Skip to content

Commit c3da5be

Browse files
committed
Remove validate method from QueryVectorBuilder interface
1 parent 13aa7a1 commit c3da5be

6 files changed

Lines changed: 40 additions & 52 deletions

File tree

server/src/main/java/org/elasticsearch/search/vectors/QueryVectorBuilder.java

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
public interface QueryVectorBuilder extends VersionedNamedWriteable, ToXContentObject {
2121

2222
/**
23-
* Method for building a vector via the client. This method is called during RerwiteAndFetch.
23+
* Method for building a vector via the client. This method is called during RewriteAndFetch.
2424
* Typical implementation for this method will:
2525
* 1. call some asynchronous client action
2626
* 2. Handle failure/success for that action (usually passing failure to the provided listener)
@@ -31,15 +31,4 @@ public interface QueryVectorBuilder extends VersionedNamedWriteable, ToXContentO
3131
* @param listener listener to accept the created vector
3232
*/
3333
void buildVector(Client client, ActionListener<float[]> listener);
34-
35-
void validate();
36-
37-
default boolean isValid() {
38-
try {
39-
validate();
40-
return true;
41-
} catch (Exception e) {
42-
return false;
43-
}
44-
}
4534
}

server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,6 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
9292
listener.onResponse(response);
9393
}
9494

95-
@Override
96-
public void validate() {}
97-
9895
@Override
9996
public boolean equals(Object o) {
10097
if (this == o) return true;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/vectors/TextEmbeddingQueryVectorBuilder.java

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
100100

101101
@Override
102102
public void buildVector(Client client, ActionListener<float[]> listener) {
103-
validate();
103+
if (modelId == null) {
104+
listener.onFailure(new IllegalArgumentException("[model_id] must not be null."));
105+
return;
106+
}
104107

105108
CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
106109
modelId,
@@ -137,13 +140,6 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
137140
}, listener::onFailure));
138141
}
139142

140-
@Override
141-
public void validate() {
142-
if (modelId == null) {
143-
throw new IllegalArgumentException("[model_id] must not be null.");
144-
}
145-
}
146-
147143
public String getModelText() {
148144
return modelText;
149145
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/InterceptedInferenceKnnVectorQueryBuilder.java

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,26 @@ private static InterceptedInferenceKnnVectorQueryBuilder rewriteQueryVectorBuild
196196
}
197197

198198
QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
199-
if (queryVectorBuilder != null && queryVectorBuilder.isValid()) {
200-
SetOnce<float[]> newQueryVectorSupplier = new SetOnce<>();
201-
QueryVectorBuilderAsyncAction.registerAction(queryRewriteContext, queryVectorBuilder, newQueryVectorSupplier);
202-
return new InterceptedInferenceKnnVectorQueryBuilder(queryBuilder, originalQuery, newQueryVectorSupplier);
199+
if (queryVectorBuilder != null) {
200+
boolean registerAction = false;
201+
if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder tevb) {
202+
// TextEmbeddingQueryVectorBuilder is a special case. If a model ID is set, we register an action to generate
203+
// the query vector. If not, the model text will be returned via getQuery() so that InferenceQueryUtils can
204+
// generate the appropriate inference results for the inferred inference ID(s).
205+
if (tevb.getModelId() != null) {
206+
registerAction = true;
207+
}
208+
} else {
209+
// We register an action to generate the query vector for all other query vector builders. If they cannot, buildVector()
210+
// should throw an error indicating why.
211+
registerAction = true;
212+
}
213+
214+
if (registerAction) {
215+
SetOnce<float[]> newQueryVectorSupplier = new SetOnce<>();
216+
QueryVectorBuilderAsyncAction.registerAction(queryRewriteContext, queryVectorBuilder, newQueryVectorSupplier);
217+
return new InterceptedInferenceKnnVectorQueryBuilder(queryBuilder, originalQuery, newQueryVectorSupplier);
218+
}
203219
}
204220

205221
return queryBuilder;
@@ -224,8 +240,8 @@ protected boolean preInferenceCoordinatorNodeValidate(ResolvedIndices resolvedIn
224240

225241
// We can skip remote cluster inference info gathering if:
226242
// - Inference fields are queried locally, guaranteeing that the query will be intercepted
227-
// - A valid query vector builder or query vector is set. In either case, remote cluster inference results are not required.
228-
return inferenceFieldsQueried > 0 && (hasValidQueryVectorBuilder() || originalQuery.queryVector() != null);
243+
// - A standalone query vector builder or query vector is set. In either case, remote cluster inference results are not required.
244+
return inferenceFieldsQueried > 0 && (hasStandaloneQueryVectorBuilder() || originalQuery.queryVector() != null);
229245
}
230246

231247
@Override
@@ -364,30 +380,26 @@ private MlDenseEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceI
364380
return (MlDenseEmbeddingResults) inferenceResults;
365381
}
366382

367-
private void missingInferenceIdOverrideCheck() {
368-
QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
369-
if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder textEmbeddingQueryVectorBuilder
370-
&& textEmbeddingQueryVectorBuilder.getModelId() == null) {
371-
throw new IllegalArgumentException("[model_id] must not be null.");
372-
}
373-
}
374-
375383
private void validateQueryVectorBuilder(boolean requireExplicitInferenceId) {
376384
QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
377-
if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder tevb) {
378-
// TextEmbeddingQueryVectorBuilder only needs validation when an explicit inference ID is required. A non-null model text value
385+
if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder tevb && requireExplicitInferenceId) {
386+
// TextEmbeddingQueryVectorBuilder needs validation when an explicit inference ID is required. A non-null model text value
379387
// is guaranteed by its constructor.
380-
if (requireExplicitInferenceId) {
381-
tevb.validate();
388+
if (tevb.getModelId() == null) {
389+
throw new IllegalArgumentException("[model_id] must not be null.");
382390
}
383-
} else if (queryVectorBuilder != null) {
384-
// Other query vector builder types always require validation
385-
queryVectorBuilder.validate();
386391
}
392+
// For other query vector builders, we don't validate upfront. buildVector() will throw an error if it cannot generate a vector.
387393
}
388394

389-
private boolean hasValidQueryVectorBuilder() {
395+
private boolean hasStandaloneQueryVectorBuilder() {
390396
QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
391-
return queryVectorBuilder != null && queryVectorBuilder.isValid();
397+
if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder tevb) {
398+
// TextEmbeddingQueryVectorBuilder is considered to be a standalone query vector builder if the model ID is set
399+
return tevb.getModelId() != null;
400+
}
401+
402+
// All other query vector builders are assumed to be standalone
403+
return queryVectorBuilder != null;
392404
}
393405
}

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/linear/LinearRetrieverIT.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -846,9 +846,6 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
846846
listener.onResponse(vector);
847847
}
848848

849-
@Override
850-
public void validate() {}
851-
852849
@Override
853850
public String getWriteableName() {
854851
throw new IllegalStateException("Should not be called");

x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -885,9 +885,6 @@ public void buildVector(Client client, ActionListener<float[]> listener) {
885885
listener.onResponse(vector);
886886
}
887887

888-
@Override
889-
public void validate() {}
890-
891888
@Override
892889
public String getWriteableName() {
893890
throw new IllegalStateException("Should not be called");

0 commit comments

Comments
 (0)