2424import org .elasticsearch .action .search .ShardSearchFailure ;
2525import org .elasticsearch .action .search .TransportSearchAction ;
2626import org .elasticsearch .client .internal .Client ;
27+ import org .elasticsearch .common .util .concurrent .AtomicArray ;
2728import org .elasticsearch .core .Releasable ;
2829import org .elasticsearch .core .Releasables ;
2930import org .elasticsearch .core .TimeValue ;
@@ -75,6 +76,8 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
7576
7677 private final SetOnce <MutableSearchResponse > searchResponse = new SetOnce <>();
7778
79+ private final ShardsInfo shardsInfo ;
80+
7881 /**
7982 * Creates an instance of {@link AsyncSearchTask}.
8083 *
@@ -112,6 +115,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable
112115 this .aggReduceContextSupplier = aggReduceContextSupplierFactory .apply (this ::isCancelled );
113116 this .progressListener = new Listener ();
114117 setProgressListener (progressListener );
118+ shardsInfo = new ShardsInfo ();
115119 }
116120
117121 /**
@@ -352,7 +356,7 @@ private AsyncSearchResponse getResponse(boolean restoreResponseHeaders) {
352356 ExceptionsHelper .status (e ),
353357 e
354358 );
355- asyncSearchResponse = mutableSearchResponse .toAsyncSearchResponse (this , expirationTimeMillis , exception );
359+ asyncSearchResponse = mutableSearchResponse .toAsyncSearchResponse (shardsInfo , this , expirationTimeMillis , exception );
356360 }
357361 return asyncSearchResponse ;
358362 }
@@ -373,6 +377,7 @@ public static AsyncStatusResponse getStatusResponse(AsyncSearchTask asyncTask) {
373377 MutableSearchResponse mutableSearchResponse = asyncTask .searchResponse .get ();
374378 assert mutableSearchResponse != null ;
375379 return mutableSearchResponse .toStatusResponse (
380+ asyncTask .shardsInfo ,
376381 asyncTask .searchId .getEncoded (),
377382 asyncTask .getStartTime (),
378383 asyncTask .expirationTimeMillis
@@ -384,11 +389,19 @@ public void close() {
384389 Releasables .close (searchResponse .get ());
385390 }
386391
392+ public ShardsInfo getShardsInfo () {
393+ return shardsInfo ;
394+ }
395+
387396 class Listener extends SearchProgressActionListener {
388397
389398 // needed when there's a single coordinator for all CCS search phases (minimize_roundtrips=false)
390399 private CCSSingleCoordinatorSearchProgressListener delegate ;
391400
401+ Listener () {
402+ searchResponse .set (new MutableSearchResponse (threadPool .getThreadContext ()));
403+ }
404+
392405 @ Override
393406 protected void onQueryResult (int shardIndex , QuerySearchResult queryResult ) {
394407 checkCancellation ();
@@ -424,7 +437,8 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc
424437 .addQueryFailure (
425438 shardIndex ,
426439 // the nodeId is null if all replicas of this shard failed
427- new ShardSearchFailure (exc , shardTarget .getNodeId () != null ? shardTarget : null )
440+ new ShardSearchFailure (exc , shardTarget .getNodeId () != null ? shardTarget : null ),
441+ shardsInfo .getQueryFailures ()
428442 );
429443 }
430444
@@ -467,14 +481,7 @@ protected void onListShards(
467481 delegate = new CCSSingleCoordinatorSearchProgressListener ();
468482 delegate .onListShards (shards , skipped , clusters , fetchPhase , timeProvider );
469483 }
470-
471- MutableSearchResponse mutableSearchResponse = searchResponse .get ();
472- if (mutableSearchResponse == null ) {
473- mutableSearchResponse = new MutableSearchResponse (clusters , threadPool .getThreadContext ());
474- searchResponse .set (mutableSearchResponse );
475- }
476-
477- mutableSearchResponse .updateShardsCount (shards .size () + skipped .size (), skipped .size ());
484+ shardsInfo .setShardDetails (shards .size () + skipped .size (), skipped .size (), clusters );
478485 executeInitListeners ();
479486 }
480487
@@ -501,7 +508,7 @@ public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, Inter
501508 */
502509 reducedAggs = () -> InternalAggregations .topLevelReduce (singletonList (aggregations ), aggReduceContextSupplier .get ());
503510 }
504- searchResponse .get ().updatePartialResponse (shards .size (), totalHits , reducedAggs , reducePhase );
511+ searchResponse .get ().updatePartialResponse (shards .size (), shardsInfo . getSkippedShards (), totalHits , reducedAggs , reducePhase );
505512 }
506513
507514 /**
@@ -515,7 +522,8 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna
515522 if (delegate != null ) {
516523 delegate .onFinalReduce (shards , totalHits , aggregations , reducePhase );
517524 }
518- searchResponse .get ().updatePartialResponse (shards .size (), totalHits , () -> aggregations , reducePhase );
525+ searchResponse .get ()
526+ .updatePartialResponse (shards .size (), shardsInfo .getSkippedShards (), totalHits , () -> aggregations , reducePhase );
519527 }
520528
521529 /**
@@ -528,26 +536,24 @@ public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, Interna
528536 @ Override
529537 public void onClusterResponseMinimizeRoundtrips (String clusterAlias , SearchResponse clusterResponse ) {
530538 // no need to call the delegate progress listener, since this method is only called for minimize_roundtrips=true
531- if (searchResponse .get () == null ) {
532- searchResponse .set (new MutableSearchResponse (getResponseClusters (), threadPool .getThreadContext ()));
533- }
534-
535539 searchResponse .get ().updateResponseMinimizeRoundtrips (clusterAlias , clusterResponse );
536540 }
537541
538542 @ Override
539543 public void onResponse (SearchResponse response ) {
540- searchResponse .get ().updateFinalResponse (response , ccsMinimizeRoundtrips );
544+ searchResponse .get ()
545+ .updateFinalResponse (shardsInfo .getTotalShards (), shardsInfo .getSkippedShards (), response , ccsMinimizeRoundtrips );
541546 executeCompletionListeners ();
542547 }
543548
544549 @ Override
545550 public void onFailure (Exception exc ) {
546551 // if the failure occurred before calling onListShards
547- var r = new MutableSearchResponse (null , threadPool .getThreadContext ());
552+ var r = new MutableSearchResponse (threadPool .getThreadContext ());
548553 if (searchResponse .trySet (r ) == false ) {
549- r .updateShardsCount (-1 , -1 );
550554 r .close ();
555+ } else {
556+ shardsInfo .setShardDetails (-1 , -1 , null );
551557 }
552558
553559 searchResponse .get ()
@@ -557,6 +563,75 @@ public void onFailure(Exception exc) {
557563 }
558564 }
559565
566+ /**
567+ * Captures shards information such as number of total shards and skipped shards once
568+ * they're made available via onListShards().
569+ */
570+ class ShardsInfo {
571+ private int totalShards ;
572+ private int skippedShards ;
573+ private AtomicArray <ShardSearchFailure > queryFailures ;
574+ private Clusters clusters ;
575+
576+ /**
577+ *
578+ * @param totalShards The number of shards that participate in the request, or -1 to indicate a failure.
579+ * @param skippedShards The number of skipped shards, or -1 to indicate a failure.
580+ * @param clusters The remote clusters statistics.
581+ */
582+ public void setShardDetails (int totalShards , int skippedShards , Clusters clusters ) {
583+ this .totalShards = totalShards ;
584+ this .skippedShards = skippedShards ;
585+ this .queryFailures = totalShards == -1 ? null : new AtomicArray <>(totalShards - skippedShards );
586+ this .clusters = clusters ;
587+ }
588+
589+ /**
590+ * @return The total number of shards participating in the search.
591+ */
592+ public int getTotalShards () {
593+ return totalShards ;
594+ }
595+
596+ /**
597+ * @return The total number of shards skipped.
598+ */
599+ public int getSkippedShards () {
600+ return skippedShards ;
601+ }
602+
603+ /**
604+ * @return All the query failures occurred.
605+ */
606+ public AtomicArray <ShardSearchFailure > getQueryFailures () {
607+ return queryFailures ;
608+ }
609+
610+ /**
611+ * @return Clusters participating in the search.
612+ */
613+ public Clusters getClusters () {
614+ return clusters ;
615+ }
616+
617+ /**
618+ * @return An array that holds query failures represented by {@link ShardSearchFailure}-s.
619+ */
620+ public ShardSearchFailure [] buildQueryFailures () {
621+ if (queryFailures == null ) {
622+ return ShardSearchFailure .EMPTY_ARRAY ;
623+ }
624+ List <ShardSearchFailure > failures = new ArrayList <>();
625+ for (int i = 0 ; i < queryFailures .length (); i ++) {
626+ ShardSearchFailure shardSearchFailure = queryFailures .get (i );
627+ if (shardSearchFailure != null ) {
628+ failures .add (shardSearchFailure );
629+ }
630+ }
631+ return failures .toArray (ShardSearchFailure []::new );
632+ }
633+ }
634+
560635 @ Override
561636 public boolean isAsync () {
562637 return true ;
0 commit comments