3232
3333package org .opensearch .action .search ;
3434
35+ import org .junit .After ;
36+ import org .junit .Before ;
3537import org .opensearch .action .ActionListener ;
3638import org .opensearch .action .OriginalIndices ;
3739import org .opensearch .action .support .IndicesOptions ;
4345import org .opensearch .index .Index ;
4446import org .opensearch .index .query .MatchAllQueryBuilder ;
4547import org .opensearch .index .shard .ShardId ;
48+ import org .opensearch .index .shard .ShardNotFoundException ;
4649import org .opensearch .search .SearchPhaseResult ;
4750import org .opensearch .search .SearchShardTarget ;
4851import org .opensearch .search .internal .AliasFilter ;
4952import org .opensearch .search .internal .InternalSearchResponse ;
5053import org .opensearch .search .internal .ShardSearchContextId ;
5154import org .opensearch .search .internal .ShardSearchRequest ;
55+ import org .opensearch .search .query .QuerySearchResult ;
5256import org .opensearch .test .OpenSearchTestCase ;
5357import org .opensearch .transport .Transport ;
5458
5559import java .util .ArrayList ;
60+ import java .util .Arrays ;
5661import java .util .Collections ;
5762import java .util .HashSet ;
5863import java .util .List ;
5964import java .util .Set ;
65+ import java .util .UUID ;
6066import java .util .concurrent .CopyOnWriteArraySet ;
67+ import java .util .concurrent .CountDownLatch ;
68+ import java .util .concurrent .ExecutorService ;
69+ import java .util .concurrent .Executors ;
6170import java .util .concurrent .TimeUnit ;
71+ import java .util .concurrent .atomic .AtomicBoolean ;
6272import java .util .concurrent .atomic .AtomicLong ;
6373import java .util .concurrent .atomic .AtomicReference ;
6474import java .util .function .BiFunction ;
75+ import java .util .stream .IntStream ;
6576
6677import static org .hamcrest .Matchers .equalTo ;
6778import static org .hamcrest .Matchers .greaterThanOrEqualTo ;
@@ -71,13 +82,49 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase {
7182
7283 private final List <Tuple <String , String >> resolvedNodes = new ArrayList <>();
7384 private final Set <ShardSearchContextId > releasedContexts = new CopyOnWriteArraySet <>();
85+ private ExecutorService executor ;
86+
87+ @ Before
88+ @ Override
89+ public void setUp () throws Exception {
90+ super .setUp ();
91+ executor = Executors .newFixedThreadPool (1 );
92+ }
93+
94+ @ After
95+ @ Override
96+ public void tearDown () throws Exception {
97+ super .tearDown ();
98+ executor .shutdown ();
99+ assertTrue (executor .awaitTermination (1 , TimeUnit .SECONDS ));
100+ }
74101
75102 private AbstractSearchAsyncAction <SearchPhaseResult > createAction (
76103 SearchRequest request ,
77104 ArraySearchPhaseResults <SearchPhaseResult > results ,
78105 ActionListener <SearchResponse > listener ,
79106 final boolean controlled ,
80107 final AtomicLong expected
108+ ) {
109+ return createAction (
110+ request ,
111+ results ,
112+ listener ,
113+ controlled ,
114+ false ,
115+ expected ,
116+ new SearchShardIterator (null , null , Collections .emptyList (), null )
117+ );
118+ }
119+
120+ private AbstractSearchAsyncAction <SearchPhaseResult > createAction (
121+ SearchRequest request ,
122+ ArraySearchPhaseResults <SearchPhaseResult > results ,
123+ ActionListener <SearchResponse > listener ,
124+ final boolean controlled ,
125+ final boolean failExecutePhaseOnShard ,
126+ final AtomicLong expected ,
127+ final SearchShardIterator ... shards
81128 ) {
82129 final Runnable runnable ;
83130 final TransportSearchAction .SearchTimeProvider timeProvider ;
@@ -105,10 +152,10 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
105152 Collections .singletonMap ("foo" , new AliasFilter (new MatchAllQueryBuilder ())),
106153 Collections .singletonMap ("foo" , 2.0f ),
107154 Collections .singletonMap ("name" , Sets .newHashSet ("bar" , "baz" )),
108- null ,
155+ executor ,
109156 request ,
110157 listener ,
111- new GroupShardsIterator <>(Collections . singletonList ( new SearchShardIterator ( null , null , Collections . emptyList (), null ) )),
158+ new GroupShardsIterator <>(Arrays . asList ( shards )),
112159 timeProvider ,
113160 ClusterState .EMPTY_STATE ,
114161 null ,
@@ -126,7 +173,13 @@ protected void executePhaseOnShard(
126173 final SearchShardIterator shardIt ,
127174 final SearchShardTarget shard ,
128175 final SearchActionListener <SearchPhaseResult > listener
129- ) {}
176+ ) {
177+ if (failExecutePhaseOnShard ) {
178+ listener .onFailure (new ShardNotFoundException (shardIt .shardId ()));
179+ } else {
180+ listener .onResponse (new QuerySearchResult ());
181+ }
182+ }
130183
131184 @ Override
132185 long buildTookInMillis () {
@@ -328,6 +381,102 @@ private static ArraySearchPhaseResults<SearchPhaseResult> phaseResults(
328381 return phaseResults ;
329382 }
330383
384+ public void testOnShardFailurePhaseDoneFailure () throws InterruptedException {
385+ final Index index = new Index ("test" , UUID .randomUUID ().toString ());
386+ final CountDownLatch latch = new CountDownLatch (1 );
387+ final AtomicBoolean fail = new AtomicBoolean (true );
388+
389+ final SearchShardIterator [] shards = IntStream .range (0 , 5 + randomInt (10 ))
390+ .mapToObj (i -> new SearchShardIterator (null , new ShardId (index , i ), Arrays .asList ("n1" , "n2" , "n3" ), null , null , null ))
391+ .toArray (SearchShardIterator []::new );
392+
393+ SearchRequest searchRequest = new SearchRequest ().allowPartialSearchResults (true );
394+ searchRequest .setMaxConcurrentShardRequests (1 );
395+
396+ final ArraySearchPhaseResults <SearchPhaseResult > queryResult = new ArraySearchPhaseResults <>(shards .length );
397+ AbstractSearchAsyncAction <SearchPhaseResult > action = createAction (
398+ searchRequest ,
399+ queryResult ,
400+ new ActionListener <SearchResponse >() {
401+ @ Override
402+ public void onResponse (SearchResponse response ) {
403+
404+ }
405+
406+ @ Override
407+ public void onFailure (Exception e ) {
408+ if (fail .compareAndSet (true , false )) {
409+ try {
410+ throw new RuntimeException ("Simulated exception" );
411+ } finally {
412+ executor .submit (() -> latch .countDown ());
413+ }
414+ }
415+ }
416+ },
417+ false ,
418+ true ,
419+ new AtomicLong (),
420+ shards
421+ );
422+ action .run ();
423+ assertTrue (latch .await (1 , TimeUnit .SECONDS ));
424+
425+ InternalSearchResponse internalSearchResponse = InternalSearchResponse .empty ();
426+ SearchResponse searchResponse = action .buildSearchResponse (internalSearchResponse , action .buildShardFailures (), null , null );
427+ assertSame (searchResponse .getAggregations (), internalSearchResponse .aggregations ());
428+ assertSame (searchResponse .getSuggest (), internalSearchResponse .suggest ());
429+ assertSame (searchResponse .getProfileResults (), internalSearchResponse .profile ());
430+ assertSame (searchResponse .getHits (), internalSearchResponse .hits ());
431+ assertThat (searchResponse .getSuccessfulShards (), equalTo (0 ));
432+ }
433+
434+ public void testOnShardSuccessPhaseDoneFailure () throws InterruptedException {
435+ final Index index = new Index ("test" , UUID .randomUUID ().toString ());
436+ final CountDownLatch latch = new CountDownLatch (1 );
437+ final AtomicBoolean fail = new AtomicBoolean (true );
438+
439+ final SearchShardIterator [] shards = IntStream .range (0 , 5 + randomInt (10 ))
440+ .mapToObj (i -> new SearchShardIterator (null , new ShardId (index , i ), Arrays .asList ("n1" , "n2" , "n3" ), null , null , null ))
441+ .toArray (SearchShardIterator []::new );
442+
443+ SearchRequest searchRequest = new SearchRequest ().allowPartialSearchResults (true );
444+ searchRequest .setMaxConcurrentShardRequests (1 );
445+
446+ final ArraySearchPhaseResults <SearchPhaseResult > queryResult = new ArraySearchPhaseResults <>(shards .length );
447+ AbstractSearchAsyncAction <SearchPhaseResult > action = createAction (
448+ searchRequest ,
449+ queryResult ,
450+ new ActionListener <SearchResponse >() {
451+ @ Override
452+ public void onResponse (SearchResponse response ) {
453+ if (fail .compareAndSet (true , false )) {
454+ throw new RuntimeException ("Simulated exception" );
455+ }
456+ }
457+
458+ @ Override
459+ public void onFailure (Exception e ) {
460+ executor .submit (() -> latch .countDown ());
461+ }
462+ },
463+ false ,
464+ false ,
465+ new AtomicLong (),
466+ shards
467+ );
468+ action .run ();
469+ assertTrue (latch .await (1 , TimeUnit .SECONDS ));
470+
471+ InternalSearchResponse internalSearchResponse = InternalSearchResponse .empty ();
472+ SearchResponse searchResponse = action .buildSearchResponse (internalSearchResponse , action .buildShardFailures (), null , null );
473+ assertSame (searchResponse .getAggregations (), internalSearchResponse .aggregations ());
474+ assertSame (searchResponse .getSuggest (), internalSearchResponse .suggest ());
475+ assertSame (searchResponse .getProfileResults (), internalSearchResponse .profile ());
476+ assertSame (searchResponse .getHits (), internalSearchResponse .hits ());
477+ assertThat (searchResponse .getSuccessfulShards (), equalTo (shards .length ));
478+ }
479+
331480 private static final class PhaseResult extends SearchPhaseResult {
332481 PhaseResult (ShardSearchContextId contextId ) {
333482 this .contextId = contextId ;
0 commit comments