Skip to content

Commit b77e89d

Browse files
[BUG] opensearch crashes on closed client connection before search reply (#3626) (#3645)
* [BUG] opensearch crashes on closed client connection before search reply

  Signed-off-by: Andriy Redko <andriy.redko@aiven.io>

* Addressing code review comments

  Signed-off-by: Andriy Redko <andriy.redko@aiven.io>

(cherry picked from commit 3dba46e)

Co-authored-by: Andriy Redko <andriy.redko@aiven.io>
Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
1 parent 5fb97cd commit b77e89d

2 files changed

Lines changed: 162 additions & 5 deletions

File tree

server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -454,7 +454,11 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
454454
}
455455
final int totalOps = this.totalOps.incrementAndGet();
456456
if (totalOps == expectedTotalOps) {
457-
onPhaseDone();
457+
try {
458+
onPhaseDone();
459+
} catch (final Exception ex) {
460+
onPhaseFailure(this, "The phase has failed", ex);
461+
}
458462
} else if (totalOps > expectedTotalOps) {
459463
throw new AssertionError(
460464
"unexpected higher total ops [" + totalOps + "] compared to expected [" + expectedTotalOps + "]",
@@ -559,7 +563,11 @@ private void successfulShardExecution(SearchShardIterator shardsIt) {
559563
}
560564
final int xTotalOps = totalOps.addAndGet(remainingOpsOnIterator);
561565
if (xTotalOps == expectedTotalOps) {
562-
onPhaseDone();
566+
try {
567+
onPhaseDone();
568+
} catch (final Exception ex) {
569+
onPhaseFailure(this, "The phase has failed", ex);
570+
}
563571
} else if (xTotalOps > expectedTotalOps) {
564572
throw new AssertionError(
565573
"unexpected higher total ops [" + xTotalOps + "] compared to expected [" + expectedTotalOps + "]",

server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java

Lines changed: 152 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
package org.opensearch.action.search;
3434

35+
import org.junit.After;
36+
import org.junit.Before;
3537
import org.opensearch.action.ActionListener;
3638
import org.opensearch.action.OriginalIndices;
3739
import org.opensearch.action.support.IndicesOptions;
@@ -43,25 +45,34 @@
4345
import org.opensearch.index.Index;
4446
import org.opensearch.index.query.MatchAllQueryBuilder;
4547
import org.opensearch.index.shard.ShardId;
48+
import org.opensearch.index.shard.ShardNotFoundException;
4649
import org.opensearch.search.SearchPhaseResult;
4750
import org.opensearch.search.SearchShardTarget;
4851
import org.opensearch.search.internal.AliasFilter;
4952
import org.opensearch.search.internal.InternalSearchResponse;
5053
import org.opensearch.search.internal.ShardSearchContextId;
5154
import org.opensearch.search.internal.ShardSearchRequest;
55+
import org.opensearch.search.query.QuerySearchResult;
5256
import org.opensearch.test.OpenSearchTestCase;
5357
import org.opensearch.transport.Transport;
5458

5559
import java.util.ArrayList;
60+
import java.util.Arrays;
5661
import java.util.Collections;
5762
import java.util.HashSet;
5863
import java.util.List;
5964
import java.util.Set;
65+
import java.util.UUID;
6066
import java.util.concurrent.CopyOnWriteArraySet;
67+
import java.util.concurrent.CountDownLatch;
68+
import java.util.concurrent.ExecutorService;
69+
import java.util.concurrent.Executors;
6170
import java.util.concurrent.TimeUnit;
71+
import java.util.concurrent.atomic.AtomicBoolean;
6272
import java.util.concurrent.atomic.AtomicLong;
6373
import java.util.concurrent.atomic.AtomicReference;
6474
import java.util.function.BiFunction;
75+
import java.util.stream.IntStream;
6576

6677
import static org.hamcrest.Matchers.equalTo;
6778
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -71,13 +82,49 @@ public class AbstractSearchAsyncActionTests extends OpenSearchTestCase {
7182

7283
private final List<Tuple<String, String>> resolvedNodes = new ArrayList<>();
7384
private final Set<ShardSearchContextId> releasedContexts = new CopyOnWriteArraySet<>();
85+
private ExecutorService executor;
86+
87+
@Before
88+
@Override
89+
public void setUp() throws Exception {
90+
super.setUp();
91+
executor = Executors.newFixedThreadPool(1);
92+
}
93+
94+
@After
95+
@Override
96+
public void tearDown() throws Exception {
97+
super.tearDown();
98+
executor.shutdown();
99+
assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS));
100+
}
74101

75102
private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
76103
SearchRequest request,
77104
ArraySearchPhaseResults<SearchPhaseResult> results,
78105
ActionListener<SearchResponse> listener,
79106
final boolean controlled,
80107
final AtomicLong expected
108+
) {
109+
return createAction(
110+
request,
111+
results,
112+
listener,
113+
controlled,
114+
false,
115+
expected,
116+
new SearchShardIterator(null, null, Collections.emptyList(), null)
117+
);
118+
}
119+
120+
private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
121+
SearchRequest request,
122+
ArraySearchPhaseResults<SearchPhaseResult> results,
123+
ActionListener<SearchResponse> listener,
124+
final boolean controlled,
125+
final boolean failExecutePhaseOnShard,
126+
final AtomicLong expected,
127+
final SearchShardIterator... shards
81128
) {
82129
final Runnable runnable;
83130
final TransportSearchAction.SearchTimeProvider timeProvider;
@@ -105,10 +152,10 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
105152
Collections.singletonMap("foo", new AliasFilter(new MatchAllQueryBuilder())),
106153
Collections.singletonMap("foo", 2.0f),
107154
Collections.singletonMap("name", Sets.newHashSet("bar", "baz")),
108-
null,
155+
executor,
109156
request,
110157
listener,
111-
new GroupShardsIterator<>(Collections.singletonList(new SearchShardIterator(null, null, Collections.emptyList(), null))),
158+
new GroupShardsIterator<>(Arrays.asList(shards)),
112159
timeProvider,
113160
ClusterState.EMPTY_STATE,
114161
null,
@@ -126,7 +173,13 @@ protected void executePhaseOnShard(
126173
final SearchShardIterator shardIt,
127174
final SearchShardTarget shard,
128175
final SearchActionListener<SearchPhaseResult> listener
129-
) {}
176+
) {
177+
if (failExecutePhaseOnShard) {
178+
listener.onFailure(new ShardNotFoundException(shardIt.shardId()));
179+
} else {
180+
listener.onResponse(new QuerySearchResult());
181+
}
182+
}
130183

131184
@Override
132185
long buildTookInMillis() {
@@ -328,6 +381,102 @@ private static ArraySearchPhaseResults<SearchPhaseResult> phaseResults(
328381
return phaseResults;
329382
}
330383

384+
public void testOnShardFailurePhaseDoneFailure() throws InterruptedException {
385+
final Index index = new Index("test", UUID.randomUUID().toString());
386+
final CountDownLatch latch = new CountDownLatch(1);
387+
final AtomicBoolean fail = new AtomicBoolean(true);
388+
389+
final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10))
390+
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), Arrays.asList("n1", "n2", "n3"), null, null, null))
391+
.toArray(SearchShardIterator[]::new);
392+
393+
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
394+
searchRequest.setMaxConcurrentShardRequests(1);
395+
396+
final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
397+
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
398+
searchRequest,
399+
queryResult,
400+
new ActionListener<SearchResponse>() {
401+
@Override
402+
public void onResponse(SearchResponse response) {
403+
404+
}
405+
406+
@Override
407+
public void onFailure(Exception e) {
408+
if (fail.compareAndSet(true, false)) {
409+
try {
410+
throw new RuntimeException("Simulated exception");
411+
} finally {
412+
executor.submit(() -> latch.countDown());
413+
}
414+
}
415+
}
416+
},
417+
false,
418+
true,
419+
new AtomicLong(),
420+
shards
421+
);
422+
action.run();
423+
assertTrue(latch.await(1, TimeUnit.SECONDS));
424+
425+
InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
426+
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
427+
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
428+
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
429+
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
430+
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
431+
assertThat(searchResponse.getSuccessfulShards(), equalTo(0));
432+
}
433+
434+
public void testOnShardSuccessPhaseDoneFailure() throws InterruptedException {
435+
final Index index = new Index("test", UUID.randomUUID().toString());
436+
final CountDownLatch latch = new CountDownLatch(1);
437+
final AtomicBoolean fail = new AtomicBoolean(true);
438+
439+
final SearchShardIterator[] shards = IntStream.range(0, 5 + randomInt(10))
440+
.mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), Arrays.asList("n1", "n2", "n3"), null, null, null))
441+
.toArray(SearchShardIterator[]::new);
442+
443+
SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true);
444+
searchRequest.setMaxConcurrentShardRequests(1);
445+
446+
final ArraySearchPhaseResults<SearchPhaseResult> queryResult = new ArraySearchPhaseResults<>(shards.length);
447+
AbstractSearchAsyncAction<SearchPhaseResult> action = createAction(
448+
searchRequest,
449+
queryResult,
450+
new ActionListener<SearchResponse>() {
451+
@Override
452+
public void onResponse(SearchResponse response) {
453+
if (fail.compareAndSet(true, false)) {
454+
throw new RuntimeException("Simulated exception");
455+
}
456+
}
457+
458+
@Override
459+
public void onFailure(Exception e) {
460+
executor.submit(() -> latch.countDown());
461+
}
462+
},
463+
false,
464+
false,
465+
new AtomicLong(),
466+
shards
467+
);
468+
action.run();
469+
assertTrue(latch.await(1, TimeUnit.SECONDS));
470+
471+
InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty();
472+
SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null);
473+
assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations());
474+
assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest());
475+
assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile());
476+
assertSame(searchResponse.getHits(), internalSearchResponse.hits());
477+
assertThat(searchResponse.getSuccessfulShards(), equalTo(shards.length));
478+
}
479+
331480
private static final class PhaseResult extends SearchPhaseResult {
332481
PhaseResult(ShardSearchContextId contextId) {
333482
this.contextId = contextId;

0 commit comments

Comments
 (0)