|
49 | 49 | import org.elasticsearch.xpack.esql.expression.predicate.logical.And; |
50 | 50 | import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull; |
51 | 51 | import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; |
| 52 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; |
| 53 | +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; |
52 | 54 | import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; |
53 | 55 | import org.elasticsearch.xpack.esql.index.EsIndex; |
54 | 56 | import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; |
@@ -1417,6 +1419,84 @@ public void testVectorFunctionsUpdateIntermediateProjections() { |
1417 | 1419 | assertTrue(esRelation.output().contains(fieldAttr)); |
1418 | 1420 | } |
1419 | 1421 |
|
| 1422 | + public void testVectorFunctionsWithDuplicateFunctions() { |
| 1423 | + assumeTrue("requires similarity functions", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); |
| 1424 | + String query = """ |
| 1425 | + from test_all |
| 1426 | + | eval s1 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]), s2 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) * 2 / 3 |
| 1427 | + | eval s3 = v_dot_product(dense_vector, [1.0, 2.0, 3.0]) + 5, r1 = v_dot_product(dense_vector, [4.0, 5.0, 6.0]) |
| 1428 | + | eval r2 = v_dot_product(dense_vector, [4.0, 5.0, 6.0]) + v_cosine(dense_vector, [4.0, 5.0, 6.0]) |
| 1429 | + | keep s1, s2, r1, r2 |
| 1430 | + """; |
| 1431 | + |
| 1432 | + LogicalPlan plan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); |
| 1433 | + |
| 1434 | + // EsqlProject[[s1{r}#5, s2{r}#8, r1{r}#14, r2{r}#18]] |
| 1435 | + var project = as(plan, EsqlProject.class); |
| 1436 | + assertThat(Expressions.names(project.projections()), contains("s1", "s2", "r1", "r2")); |
| 1437 | + |
| 1438 | + // Eval with s1, s2, r1, r2 |
| 1439 | + var eval = as(project.child(), Eval.class); |
| 1440 | + assertThat(eval.fields(), hasSize(4)); |
| 1441 | + |
| 1442 | + // Check s1 = $$dense_vector$DotProduct$... |
| 1443 | + var s1Alias = as(eval.fields().getFirst(), Alias.class); |
| 1444 | + assertThat(s1Alias.name(), equalTo("s1")); |
| 1445 | + var s1FieldAttr = as(s1Alias.child(), FieldAttribute.class); |
| 1446 | + assertThat(s1FieldAttr.fieldName().string(), equalTo("dense_vector")); |
| 1447 | + assertThat(s1FieldAttr.name(), startsWith("$$dense_vector$DotProduct")); |
| 1448 | + var s1Field = as(s1FieldAttr.field(), FunctionEsField.class); |
| 1449 | + var s1Config = as(s1Field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); |
| 1450 | + assertThat(s1Config.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); |
| 1451 | + assertThat(s1Config.vector(), equalTo(new float[] { 1.0f, 2.0f, 3.0f })); |
| 1452 | + |
| 1453 | + // Check s2 = $$dense_vector$DotProduct$1606418432 * 2 / 3 (same field as s1) |
| 1454 | + var s2Alias = as(eval.fields().get(1), Alias.class); |
| 1455 | + assertThat(s2Alias.name(), equalTo("s2")); |
| 1456 | + var s2Div = as(s2Alias.child(), Div.class); |
| 1457 | + var s2Mul = as(s2Div.left(), Mul.class); |
| 1458 | + var s2FieldAttr = as(s2Mul.left(), FieldAttribute.class); |
| 1459 | + assertThat(s1FieldAttr, is(s2FieldAttr)); |
| 1460 | + |
| 1461 | + // Check r1 = $$dense_vector$DotProduct$882900992 (vector [4.0, 5.0, 6.0]) |
| 1462 | + var r1Alias = as(eval.fields().get(2), Alias.class); |
| 1463 | + assertThat(r1Alias.name(), equalTo("r1")); |
| 1464 | + var r1FieldAttr = as(r1Alias.child(), FieldAttribute.class); |
| 1465 | + assertThat(r1FieldAttr.fieldName().string(), equalTo("dense_vector")); |
| 1466 | + assertThat(r1FieldAttr.name(), startsWith("$$dense_vector$DotProduct")); |
| 1467 | + var r1Field = as(r1FieldAttr.field(), FunctionEsField.class); |
| 1468 | + var r1Config = as(r1Field.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); |
| 1469 | + assertThat(r1Config.similarityFunction(), is(DotProduct.SIMILARITY_FUNCTION)); |
| 1470 | + assertThat(r1Config.vector(), equalTo(new float[] { 4.0f, 5.0f, 6.0f })); |
| 1471 | + |
| 1472 | + // Check r2 = $$dense_vector$DotProduct$882900992 + $$dense_vector$CosineSimilarity$882900992 |
| 1473 | + var r2Alias = as(eval.fields().get(3), Alias.class); |
| 1474 | + assertThat(r2Alias.name(), equalTo("r2")); |
| 1475 | + var r2Add = as(r2Alias.child(), Add.class); |
| 1476 | + |
| 1477 | + // Left side: DotProduct field (same as r1) |
| 1478 | + var r2DotProductFieldAttr = as(r2Add.left(), FieldAttribute.class); |
| 1479 | + assertThat(r2DotProductFieldAttr, is(r1FieldAttr)); |
| 1480 | + |
| 1481 | + // Right side: CosineSimilarity field |
| 1482 | + var r2CosineFieldAttr = as(r2Add.right(), FieldAttribute.class); |
| 1483 | + assertThat(r2CosineFieldAttr.fieldName().string(), equalTo("dense_vector")); |
| 1484 | + assertThat(r2CosineFieldAttr.name(), startsWith("$$dense_vector$CosineSimilarity")); |
| 1485 | + var r2CosineField = as(r2CosineFieldAttr.field(), FunctionEsField.class); |
| 1486 | + var r2CosineConfig = as(r2CosineField.functionConfig(), DenseVectorFieldMapper.VectorSimilarityFunctionConfig.class); |
| 1487 | + assertThat(r2CosineConfig.similarityFunction(), is(CosineSimilarity.SIMILARITY_FUNCTION)); |
| 1488 | + assertThat(r2CosineConfig.vector(), equalTo(new float[] { 4.0f, 5.0f, 6.0f })); |
| 1489 | + |
| 1490 | + // Limit[1000[INTEGER],false,false] |
| 1491 | + var limit = as(eval.child(), Limit.class); |
| 1492 | + |
| 1493 | + // EsRelation[test_all][!alias_integer, boolean{f}#24, byte{f}#25, constant..] |
| 1494 | + var esRelation = as(limit.child(), EsRelation.class); |
| 1495 | + assertTrue(esRelation.output().contains(s1FieldAttr)); |
| 1496 | + assertTrue(esRelation.output().contains(r1FieldAttr)); |
| 1497 | + assertTrue(esRelation.output().contains(r2CosineFieldAttr)); |
| 1498 | + } |
| 1499 | + |
1420 | 1500 | private IsNotNull isNotNull(Expression field) { |
1421 | 1501 | return new IsNotNull(EMPTY, field); |
1422 | 1502 | } |
|
0 commit comments