Skip to content

Commit e7b8a7c

Browse files
Move mutable entities such as shard counts out of MutableSearchResponse
1 parent 5760dd2 commit e7b8a7c

2 files changed

Lines changed: 142 additions & 77 deletions

File tree

x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java

Lines changed: 94 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.elasticsearch.action.search.ShardSearchFailure;
2525
import org.elasticsearch.action.search.TransportSearchAction;
2626
import org.elasticsearch.client.internal.Client;
27+
import org.elasticsearch.common.util.concurrent.AtomicArray;
2728
import org.elasticsearch.core.Releasable;
2829
import org.elasticsearch.core.Releasables;
2930
import org.elasticsearch.core.TimeValue;
@@ -75,6 +76,8 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
7576

7677
private final SetOnce<MutableSearchResponse> searchResponse = new SetOnce<>();
7778

79+
private final ShardsInfo shardsInfo;
80+
7881
/**
7982
* Creates an instance of {@link AsyncSearchTask}.
8083
*
@@ -112,6 +115,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
112115
this.aggReduceContextSupplier = aggReduceContextSupplierFactory.apply(this::isCancelled);
113116
this.progressListener = new Listener();
114117
setProgressListener(progressListener);
118+
shardsInfo = new ShardsInfo();
115119
}
116120

117121
/**
@@ -352,7 +356,7 @@ private AsyncSearchResponse getResponse(boolean restoreResponseHeaders) {
352356
ExceptionsHelper.status(e),
353357
e
354358
);
355-
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, exception);
359+
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(shardsInfo, this, expirationTimeMillis, exception);
356360
}
357361
return asyncSearchResponse;
358362
}
@@ -373,6 +377,7 @@ public static AsyncStatusResponse getStatusResponse(AsyncSearchTask asyncTask) {
373377
MutableSearchResponse mutableSearchResponse = asyncTask.searchResponse.get();
374378
assert mutableSearchResponse != null;
375379
return mutableSearchResponse.toStatusResponse(
380+
asyncTask.shardsInfo,
376381
asyncTask.searchId.getEncoded(),
377382
asyncTask.getStartTime(),
378383
asyncTask.expirationTimeMillis
@@ -384,11 +389,19 @@ public void close() {
384389
Releasables.close(searchResponse.get());
385390
}
386391

392+
public ShardsInfo getShardsInfo() {
393+
return shardsInfo;
394+
}
395+
387396
class Listener extends SearchProgressActionListener {
388397

389398
// needed when there's a single coordinator for all CCS search phases (minimize_roundtrips=false)
390399
private CCSSingleCoordinatorSearchProgressListener delegate;
391400

401+
Listener() {
402+
searchResponse.set(new MutableSearchResponse(threadPool.getThreadContext()));
403+
}
404+
392405
@Override
393406
protected void onQueryResult(int shardIndex, QuerySearchResult queryResult) {
394407
checkCancellation();
@@ -424,7 +437,8 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc
424437
.addQueryFailure(
425438
shardIndex,
426439
// the nodeId is null if all replicas of this shard failed
427-
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null)
440+
new ShardSearchFailure(exc, shardTarget.getNodeId() != null ? shardTarget : null),
441+
shardsInfo.getQueryFailures()
428442
);
429443
}
430444

@@ -467,14 +481,7 @@ protected void onListShards(
467481
delegate = new CCSSingleCoordinatorSearchProgressListener();
468482
delegate.onListShards(shards, skipped, clusters, fetchPhase, timeProvider);
469483
}
470-
471-
MutableSearchResponse mutableSearchResponse = searchResponse.get();
472-
if (mutableSearchResponse == null) {
473-
mutableSearchResponse = new MutableSearchResponse(clusters, threadPool.getThreadContext());
474-
searchResponse.set(mutableSearchResponse);
475-
}
476-
477-
mutableSearchResponse.updateShardsCount(shards.size() + skipped.size(), skipped.size());
484+
shardsInfo.setShardDetails(shards.size() + skipped.size(), skipped.size(), clusters);
478485
executeInitListeners();
479486
}
480487

@@ -501,7 +508,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
501508
*/
502509
reducedAggs = () -> InternalAggregations.topLevelReduce(singletonList(aggregations), aggReduceContextSupplier.get());
503510
}
504-
searchResponse.get().updatePartialResponse(shards.size(), totalHits, reducedAggs, reducePhase);
511+
searchResponse.get().updatePartialResponse(shards.size(), shardsInfo.getSkippedShards(), totalHits, reducedAggs, reducePhase);
505512
}
506513

507514
/**
@@ -515,7 +522,8 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna
515522
if (delegate != null) {
516523
delegate.onFinalReduce(shards, totalHits, aggregations, reducePhase);
517524
}
518-
searchResponse.get().updatePartialResponse(shards.size(), totalHits, () -> aggregations, reducePhase);
525+
searchResponse.get()
526+
.updatePartialResponse(shards.size(), shardsInfo.getSkippedShards(), totalHits, () -> aggregations, reducePhase);
519527
}
520528

521529
/**
@@ -528,26 +536,24 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna
528536
@Override
529537
public void onClusterResponseMinimizeRoundtrips(String clusterAlias, SearchResponse clusterResponse) {
530538
// no need to call the delegate progress listener, since this method is only called for minimize_roundtrips=true
531-
if (searchResponse.get() == null) {
532-
searchResponse.set(new MutableSearchResponse(getResponseClusters(), threadPool.getThreadContext()));
533-
}
534-
535539
searchResponse.get().updateResponseMinimizeRoundtrips(clusterAlias, clusterResponse);
536540
}
537541

538542
@Override
539543
public void onResponse(SearchResponse response) {
540-
searchResponse.get().updateFinalResponse(response, ccsMinimizeRoundtrips);
544+
searchResponse.get()
545+
.updateFinalResponse(shardsInfo.getTotalShards(), shardsInfo.getSkippedShards(), response, ccsMinimizeRoundtrips);
541546
executeCompletionListeners();
542547
}
543548

544549
@Override
545550
public void onFailure(Exception exc) {
546551
// if the failure occurred before calling onListShards
547-
var r = new MutableSearchResponse(null, threadPool.getThreadContext());
552+
var r = new MutableSearchResponse(threadPool.getThreadContext());
548553
if (searchResponse.trySet(r) == false) {
549-
r.updateShardsCount(-1, -1);
550554
r.close();
555+
} else {
556+
shardsInfo.setShardDetails(-1, -1, null);
551557
}
552558

553559
searchResponse.get()
@@ -557,6 +563,75 @@ public void onFailure(Exception exc) {
557563
}
558564
}
559565

566+
/**
567+
* Captures shards information such as number of total shards and skipped shards once
568+
* they're made available via onListShards().
569+
*/
570+
class ShardsInfo {
571+
private int totalShards;
572+
private int skippedShards;
573+
private AtomicArray<ShardSearchFailure> queryFailures;
574+
private Clusters clusters;
575+
576+
/**
577+
*
578+
* @param totalShards The number of shards that participate in the request, or -1 to indicate a failure.
579+
* @param skippedShards The number of skipped shards, or -1 to indicate a failure.
580+
* @param clusters The remote clusters statistics.
581+
*/
582+
public void setShardDetails(int totalShards, int skippedShards, Clusters clusters) {
583+
this.totalShards = totalShards;
584+
this.skippedShards = skippedShards;
585+
this.queryFailures = totalShards == -1 ? null : new AtomicArray<>(totalShards - skippedShards);
586+
this.clusters = clusters;
587+
}
588+
589+
/**
590+
* @return The total number of shards participating in the search.
591+
*/
592+
public int getTotalShards() {
593+
return totalShards;
594+
}
595+
596+
/**
597+
* @return The total number of shards skipped.
598+
*/
599+
public int getSkippedShards() {
600+
return skippedShards;
601+
}
602+
603+
/**
604+
* @return All the query failures occurred.
605+
*/
606+
public AtomicArray<ShardSearchFailure> getQueryFailures() {
607+
return queryFailures;
608+
}
609+
610+
/**
611+
* @return Clusters participating in the search.
612+
*/
613+
public Clusters getClusters() {
614+
return clusters;
615+
}
616+
617+
/**
618+
* @return An array that holds query failures represented by {@link ShardSearchFailure}-s.
619+
*/
620+
public ShardSearchFailure[] buildQueryFailures() {
621+
if (queryFailures == null) {
622+
return ShardSearchFailure.EMPTY_ARRAY;
623+
}
624+
List<ShardSearchFailure> failures = new ArrayList<>();
625+
for (int i = 0; i < queryFailures.length(); i++) {
626+
ShardSearchFailure shardSearchFailure = queryFailures.get(i);
627+
if (shardSearchFailure != null) {
628+
failures.add(shardSearchFailure);
629+
}
630+
}
631+
return failures.toArray(ShardSearchFailure[]::new);
632+
}
633+
}
634+
560635
@Override
561636
public boolean isAsync() {
562637
return true;

0 commit comments

Comments
 (0)