Skip to content

Commit 11a6df1

Browse files
author
elasticsearchmachine
committed
Deduplicate attributes
1 parent 0f12e78 commit 11a6df1

2 files changed

Lines changed: 104 additions & 11 deletions

File tree

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushDownVectorSimilarityFunctions.java

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject;
2828

2929
import java.util.ArrayList;
30+
import java.util.HashMap;
3031
import java.util.List;
32+
import java.util.Map;
3133

3234
import static org.elasticsearch.xpack.esql.core.expression.Attribute.rawTemporaryName;
3335

@@ -47,7 +49,7 @@ public PushDownVectorSimilarityFunctions() {
4749
@Override
4850
protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext context) {
4951
if (plan instanceof Eval || plan instanceof Filter || plan instanceof Aggregate) {
50-
AttributeSet.Builder addedAttrs = AttributeSet.builder();
52+
Map<Attribute.IdIgnoringWrapper, Attribute> addedAttrs = new HashMap<>();
5153
LogicalPlan transformedPlan = plan.transformExpressionsOnly(
5254
VectorSimilarityFunction.class,
5355
similarityFunction -> replaceFieldsForFieldTransformations(similarityFunction, addedAttrs, context)
@@ -59,15 +61,19 @@ protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext contex
5961

6062
List<Attribute> previousAttrs = transformedPlan.output();
6163
// Transforms EsRelation to extract the new attribute
62-
AttributeSet attrSet = addedAttrs.build();
64+
65+
List<Attribute> addedAttrsList = addedAttrs.values().stream().toList();
6366
transformedPlan = transformedPlan.transformDown(
6467
EsRelation.class,
65-
esRelation -> esRelation.withAttributes(attrSet.combine(esRelation.outputSet()).stream().toList())
68+
esRelation -> {
69+
AttributeSet updatedOutput = esRelation.outputSet().combine(AttributeSet.of(addedAttrsList));
70+
return esRelation.withAttributes(updatedOutput.stream().toList());
71+
}
6672
);
6773
// Transforms Projects so the new attribute is not discarded
6874
transformedPlan = transformedPlan.transformDown(EsqlProject.class, esProject -> {
6975
List<NamedExpression> projections = new ArrayList<>(esProject.projections());
70-
projections.addAll(attrSet.stream().toList());
76+
projections.addAll(addedAttrsList);
7177
return esProject.withProjections(projections);
7278
});
7379

@@ -79,7 +85,7 @@ protected LogicalPlan rule(LogicalPlan plan, LocalLogicalOptimizerContext contex
7985

8086
private static Expression replaceFieldsForFieldTransformations(
8187
VectorSimilarityFunction similarityFunction,
82-
AttributeSet.Builder addedAttrs,
88+
Map<Attribute.IdIgnoringWrapper, Attribute> addedAttrs,
8389
LocalLogicalOptimizerContext context
8490
) {
8591
// Only replace if exactly one side is a literal and the other a field attribute
@@ -102,8 +108,10 @@ private static Expression replaceFieldsForFieldTransformations(
102108
@SuppressWarnings("unchecked")
103109
List<Number> vectorList = (List<Number>) literal.value();
104110
float[] vectorArray = new float[vectorList.size()];
111+
int arrayHashCode = 0;
105112
for (int i = 0; i < vectorList.size(); i++) {
106113
vectorArray[i] = vectorList.get(i).floatValue();
114+
arrayHashCode = 31 * arrayHashCode + Float.floatToIntBits(vectorArray[i]);
107115
}
108116

109117
// Change the similarity function to a reference of a transformation on the field
@@ -112,19 +120,24 @@ private static Expression replaceFieldsForFieldTransformations(
112120
similarityFunction.dataType(),
113121
similarityFunction.getBlockLoaderFunctionConfig()
114122
);
115-
var nameId = new NameId();
116-
var name = rawTemporaryName(fieldAttr.name(), similarityFunction.nodeName(), nameId.toString());
117-
var functionAttr = new FieldAttribute(
123+
var name = rawTemporaryName(fieldAttr.name(), similarityFunction.nodeName(), String.valueOf(arrayHashCode));
124+
// TODO: Check if exists before adding, retrieve the previous one
125+
var newFunctionAttr = new FieldAttribute(
118126
fieldAttr.source(),
119127
fieldAttr.parentName(),
120128
fieldAttr.qualifier(),
121129
name,
122130
functionEsField,
123131
fieldAttr.nullable(),
124-
nameId,
132+
new NameId(),
125133
true
126134
);
127-
addedAttrs.add(functionAttr);
128-
return functionAttr;
135+
Attribute.IdIgnoringWrapper key = newFunctionAttr.ignoreId();
136+
if (addedAttrs.containsKey(key)) {;
137+
return addedAttrs.get(key);
138+
}
139+
140+
addedAttrs.put(key, newFunctionAttr);
141+
return newFunctionAttr;
129142
}
130143
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
import org.elasticsearch.xpack.esql.expression.predicate.logical.And;
5050
import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
5151
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
52+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
53+
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
5254
import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
5355
import org.elasticsearch.xpack.esql.index.EsIndex;
5456
import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules;
@@ -1417,6 +1419,84 @@ public void testVectorFunctionsUpdateIntermediateProjections() {
14171419
assertTrue(esRelation.output().contains(fieldAttr));
14181420
}
14191421

1422+
public void testVectorFunctionsWithDuplicateFunctions() {
1423+
assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled());
1424+
String query = """
1425+
from test_all
1426+
| eval s1 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]), s2 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) * 2 / 3
1427+
| eval s3 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) + 5, r1 = v_dot_product(dense_vector, [4.0, 5.0, 6.0])
1428+
| eval r2 = v_dot_product(dense_vector, [4.0, 5.0, 6.0]) + v_cosine(dense_vector, [4.0, 5.0, 6.0])
1429+
| keep s1, s2, r1, r2
1430+
""";
1431+
1432+
LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS);
1433+
1434+
// EsqlProject[[s1{r}#5, s2{r}#8, r1{r}#14, r2{r}#18]]
1435+
var project = as(plan, EsqlProject.class);
1436+
assertThat(Expressions.names(project.projections()), contains("s1", "s2", "r1", "r2"));
1437+
1438+
// Eval with s1, s2, r1, r2
1439+
var eval = as(project.child(), Eval.class);
1440+
assertThat(eval.fields(), hasSize(4));
1441+
1442+
// Check s1 = $$dense_vector$DotProduct$...
1443+
var s1Alias = as(eval.fields().getFirst(), Alias.class);
1444+
assertThat(s1Alias.name(), equalTo("s1"));
1445+
var s1FieldAttr = as(s1Alias.child(), FieldAttribute.class);
1446+
assertThat(s1FieldAttr.fieldName().string(), equalTo("dense_vector"));
1447+
assertThat(s1FieldAttr.name(), startsWith("$$dense_vector$DotProduct"));
1448+
var s1Field = as(s1FieldAttr.field(), FunctionEsField.class);
1449+
var s1Config = as(s1Field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class);
1450+
assertThat(s1Config.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION));
1451+
assertThat(s1Config.vector(), equalTo(new float[] { 1.0f, 2.0f, 3.0f }));
1452+
1453+
// Check s2 = $$dense_vector$DotProduct$1606418432 * 2 / 3 (same field as s1)
1454+
var s2Alias = as(eval.fields().get(1), Alias.class);
1455+
assertThat(s2Alias.name(), equalTo("s2"));
1456+
var s2Div = as(s2Alias.child(), Div.class);
1457+
var s2Mul = as(s2Div.left(), Mul.class);
1458+
var s2FieldAttr = as(s2Mul.left(), FieldAttribute.class);
1459+
assertThat(s1FieldAttr, is(s2FieldAttr));
1460+
1461+
// Check r1 = $$dense_vector$DotProduct$882900992 (vector [4.0, 5.0, 6.0])
1462+
var r1Alias = as(eval.fields().get(2), Alias.class);
1463+
assertThat(r1Alias.name(), equalTo("r1"));
1464+
var r1FieldAttr = as(r1Alias.child(), FieldAttribute.class);
1465+
assertThat(r1FieldAttr.fieldName().string(), equalTo("dense_vector"));
1466+
assertThat(r1FieldAttr.name(), startsWith("$$dense_vector$DotProduct"));
1467+
var r1Field = as(r1FieldAttr.field(), FunctionEsField.class);
1468+
var r1Config = as(r1Field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class);
1469+
assertThat(r1Config.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION));
1470+
assertThat(r1Config.vector(), equalTo(new float[] { 4.0f, 5.0f, 6.0f }));
1471+
1472+
// Check r2 = $$dense_vector$DotProduct$882900992 + $$dense_vector$CosineSimilarity$882900992
1473+
var r2Alias = as(eval.fields().get(3), Alias.class);
1474+
assertThat(r2Alias.name(), equalTo("r2"));
1475+
var r2Add = as(r2Alias.child(), Add.class);
1476+
1477+
// Left side: DotProduct field (same as r1)
1478+
var r2DotProductFieldAttr = as(r2Add.left(), FieldAttribute.class);
1479+
assertThat(r2DotProductFieldAttr, is(r1FieldAttr));
1480+
1481+
// Right side: CosineSimilarity field
1482+
var r2CosineFieldAttr = as(r2Add.right(), FieldAttribute.class);
1483+
assertThat(r2CosineFieldAttr.fieldName().string(), equalTo("dense_vector"));
1484+
assertThat(r2CosineFieldAttr.name(), startsWith("$$dense_vector$CosineSimilarity"));
1485+
var r2CosineField = as(r2CosineFieldAttr.field(), FunctionEsField.class);
1486+
var r2CosineConfig = as(r2CosineField.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class);
1487+
assertThat(r2CosineConfig.similarityFunction(), is(CosineSimilarity.SIMILARITY_FUNCTION));
1488+
assertThat(r2CosineConfig.vector(), equalTo(new float[] { 4.0f, 5.0f, 6.0f }));
1489+
1490+
// Limit[1000[INTEGER],false,false]
1491+
var limit = as(eval.child(), Limit.class);
1492+
1493+
// EsRelation[test_all][!alias_integer, boolean{f}#24, byte{f}#25, constant..]
1494+
var esRelation = as(limit.child(), EsRelation.class);
1495+
assertTrue(esRelation.output().contains(s1FieldAttr));
1496+
assertTrue(esRelation.output().contains(r1FieldAttr));
1497+
assertTrue(esRelation.output().contains(r2CosineFieldAttr));
1498+
}
1499+
14201500
private IsNotNull isNotNull(Expression field) {
14211501
return new IsNotNull(EMPTY, field);
14221502
}

0 commit comments

Comments
 (0)