Skip to content

Commit ed4df62

Browse files
authored
Fix NPE and add additional graceful error handling (#2687)
* Fix NPE and add additional graceful error handling Signed-off-by: Craig Perkins <cwperx@amazon.com> * Add new lines at end of file Signed-off-by: Craig Perkins <cwperx@amazon.com> * Run spotlessApply Signed-off-by: Craig Perkins <cwperx@amazon.com> * Remove unused import Signed-off-by: Craig Perkins <cwperx@amazon.com> * volatile to final Signed-off-by: Craig Perkins <cwperx@amazon.com> --------- Signed-off-by: Craig Perkins <cwperx@amazon.com>
1 parent bbd43ec commit ed4df62

3 files changed

Lines changed: 239 additions & 6 deletions

File tree

src/main/java/org/opensearch/security/OpenSearchSecurityPlugin.java

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import java.util.Map;
4646
import java.util.Objects;
4747
import java.util.Set;
48+
import java.util.concurrent.atomic.AtomicReference;
4849
import java.util.function.BiFunction;
4950
import java.util.function.Function;
5051
import java.util.function.Predicate;
@@ -98,7 +99,6 @@
9899
import org.opensearch.index.Index;
99100
import org.opensearch.index.IndexModule;
100101
import org.opensearch.index.cache.query.QueryCache;
101-
import org.opensearch.index.shard.SearchOperationListener;
102102
import org.opensearch.indices.IndicesService;
103103
import org.opensearch.indices.SystemIndexDescriptor;
104104
import org.opensearch.indices.breaker.CircuitBreakerService;
@@ -165,6 +165,7 @@
165165
import org.opensearch.security.ssl.transport.SecuritySSLNettyTransport;
166166
import org.opensearch.security.ssl.util.SSLConfigConstants;
167167
import org.opensearch.security.support.ConfigConstants;
168+
import org.opensearch.security.support.GuardedSearchOperationWrapper;
168169
import org.opensearch.security.support.HeaderHelper;
169170
import org.opensearch.security.support.ModuleInfo;
170171
import org.opensearch.security.support.ReflectionHelper;
@@ -215,7 +216,7 @@ public final class OpenSearchSecurityPlugin extends OpenSearchSecuritySSLPlugin
215216
private final List<String> demoCertHashes = new ArrayList<String>(3);
216217
private volatile SecurityFilter sf;
217218
private volatile IndexResolverReplacer irr;
218-
private volatile NamedXContentRegistry namedXContentRegistry = null;
219+
private final AtomicReference<NamedXContentRegistry> namedXContentRegistry = new AtomicReference<>(NamedXContentRegistry.EMPTY);;
219220
private volatile DlsFlsRequestValve dlsFlsValve = null;
220221
private volatile Salt salt;
221222
private volatile OpensearchDynamicSetting<Boolean> transportPassiveAuthSetting;
@@ -569,11 +570,11 @@ public Weight doCache(Weight weight, QueryCachingPolicy policy) {
569570
}
570571
});
571572

572-
indexModule.addSearchOperationListener(new SearchOperationListener() {
573+
indexModule.addSearchOperationListener(new GuardedSearchOperationWrapper() {
573574

574575
@Override
575576
public void onPreQueryPhase(SearchContext context) {
576-
dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry);
577+
dlsFlsValve.handleSearchContext(context, threadPool, namedXContentRegistry.get());
577578
}
578579

579580
@Override
@@ -643,7 +644,7 @@ public void onQueryPhase(SearchContext searchContext, long tookInNanos) {
643644
}
644645
}
645646
}
646-
});
647+
}.toListener());
647648
}
648649
}
649650

@@ -798,6 +799,7 @@ public Collection<Object> createComponents(Client localClient, ClusterService cl
798799

799800
final PrivilegesInterceptor privilegesInterceptor;
800801

802+
namedXContentRegistry.set(xContentRegistry);
801803
if (SSLConfig.isSslOnlyMode()) {
802804
dlsFlsValve = new DlsFlsRequestValve.NoopDlsFlsRequestValve();
803805
auditLog = new NullAuditLog();
@@ -822,7 +824,7 @@ public Collection<Object> createComponents(Client localClient, ClusterService cl
822824
// DLS-FLS is enabled if not client and not disabled and not SSL only.
823825
final boolean dlsFlsEnabled = !SSLConfig.isSslOnlyMode();
824826
evaluator = new PrivilegesEvaluator(clusterService, threadPool, cr, resolver, auditLog,
825-
settings, privilegesInterceptor, cih, irr, dlsFlsEnabled, namedXContentRegistry);
827+
settings, privilegesInterceptor, cih, irr, dlsFlsEnabled, namedXContentRegistry.get());
826828

827829
sf = new SecurityFilter(settings, evaluator, adminDns, dlsFlsValve, auditLog, threadPool, cs, compatConfig, irr, xffResolver);
828830

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*/
11+
12+
package org.opensearch.security.support;
13+
14+
import org.apache.logging.log4j.LogManager;
15+
import org.apache.logging.log4j.Logger;
16+
17+
import org.opensearch.index.shard.SearchOperationListener;
18+
import org.opensearch.search.internal.ReaderContext;
19+
import org.opensearch.search.internal.SearchContext;
20+
import org.opensearch.transport.TransportRequest;
21+
22+
/**
23+
* Guarded version of Search Operation Listener to ensure critical request paths succeed
24+
*/
25+
public interface GuardedSearchOperationWrapper {
26+
27+
static final Logger log = LogManager.getLogger(GuardedSearchOperationWrapper.class);
28+
29+
void onPreQueryPhase(final SearchContext context);
30+
31+
void onNewReaderContext(final ReaderContext readerContext);
32+
33+
void onNewScrollContext(final ReaderContext readerContext);
34+
35+
void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest);
36+
37+
void onQueryPhase(final SearchContext searchContext, final long tookInNanos);
38+
39+
default SearchOperationListener toListener() {
40+
return new InnerSearchOperationListener(this);
41+
}
42+
43+
static class InnerSearchOperationListener implements SearchOperationListener {
44+
45+
private GuardedSearchOperationWrapper that;
46+
InnerSearchOperationListener(GuardedSearchOperationWrapper that) {
47+
this.that = that;
48+
}
49+
50+
@Override
51+
public void onPreQueryPhase(final SearchContext searchContext) {
52+
try {
53+
that.onPreQueryPhase(searchContext);
54+
} catch (final Exception e) {
55+
searchContext.setTask(null);
56+
log.error("Cancelled request due to internal error", e);
57+
}
58+
}
59+
60+
@Override
61+
public void onNewReaderContext(final ReaderContext readerContext) {
62+
that.onNewReaderContext(readerContext);
63+
}
64+
65+
@Override
66+
public void onNewScrollContext(final ReaderContext readerContext) {
67+
that.onNewScrollContext(readerContext);
68+
}
69+
70+
@Override
71+
public void validateReaderContext(final ReaderContext readerContext, final TransportRequest transportRequest) {
72+
that.validateReaderContext(readerContext, transportRequest);
73+
}
74+
75+
@Override
76+
public void onQueryPhase(final SearchContext searchContext, final long tookInNanos) {
77+
try {
78+
that.onQueryPhase(searchContext, tookInNanos);
79+
} catch (final Exception e) {
80+
searchContext.setTask(null);
81+
log.error("Cancelled request due to internal error", e);
82+
}
83+
}
84+
}
85+
}
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*
8+
* Modifications Copyright OpenSearch Contributors. See
9+
* GitHub history for details.
10+
*/
11+
package org.opensearch.security.support;
12+
13+
import java.util.concurrent.atomic.AtomicReference;
14+
15+
import org.junit.Test;
16+
17+
import org.opensearch.index.shard.SearchOperationListener;
18+
import org.opensearch.search.internal.ReaderContext;
19+
import org.opensearch.search.internal.SearchContext;
20+
import org.opensearch.transport.TransportRequest;
21+
22+
import static org.hamcrest.MatcherAssert.assertThat;
23+
import static org.hamcrest.Matchers.equalTo;
24+
import static org.hamcrest.Matchers.notNullValue;
25+
import static org.junit.Assert.assertThrows;
26+
import static org.mockito.Mockito.mock;
27+
import static org.mockito.Mockito.verify;
28+
29+
30+
public class GuardedSearchOperationWrapperTest {
31+
32+
@Test
33+
public void onNewReaderContextCanThrowException() {
34+
final String expectedExceptionText = "abcd1234";
35+
36+
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
37+
@Override
38+
public void onNewReaderContext(ReaderContext readerContext) {
39+
throw new RuntimeException(expectedExceptionText);
40+
}
41+
};
42+
43+
final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);
44+
45+
assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
46+
}
47+
48+
@Test
49+
public void onNewScrollContextCanThrowException() {
50+
final String expectedExceptionText = "qwerty978";
51+
52+
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
53+
@Override
54+
public void onNewScrollContext(ReaderContext readerContext) {
55+
throw new RuntimeException(expectedExceptionText);
56+
}
57+
};
58+
59+
final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);
60+
61+
assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
62+
}
63+
64+
@Test
65+
public void validateReaderContextCanThrowException() {
66+
final String expectedExceptionText = "validationException";
67+
68+
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
69+
@Override
70+
public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) {
71+
throw new RuntimeException(expectedExceptionText);
72+
}
73+
};
74+
75+
final RuntimeException expectedException = assertThrows(RuntimeException.class, testWrapper::exerciseAllMethods);
76+
77+
assertThat(expectedException.getMessage(), equalTo(expectedExceptionText));
78+
}
79+
80+
@Test
81+
public void onPreQueryPhaseCannotThrow() {
82+
AtomicReference<SearchContext> calledSearchContext = new AtomicReference<>();
83+
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
84+
@Override
85+
public void onPreQueryPhase(SearchContext context) {
86+
calledSearchContext.set(context);
87+
throw new RuntimeException("EXCEPTIONAL!");
88+
}
89+
};
90+
91+
testWrapper.exerciseAllMethods();
92+
93+
assertThat(calledSearchContext.get(), notNullValue());
94+
verify(calledSearchContext.get()).setTask(null);
95+
}
96+
97+
@Test
98+
public void onQueryPhaseCannotThrow() {
99+
AtomicReference<SearchContext> calledSearchContext = new AtomicReference<>();
100+
DefaultingGuardedSearchOperationWrapper testWrapper = new DefaultingGuardedSearchOperationWrapper() {
101+
@Override
102+
public void onQueryPhase(SearchContext context, long tookInNanos) {
103+
calledSearchContext.set(context);
104+
throw new RuntimeException("EXCEPTIONAL!");
105+
}
106+
};
107+
108+
testWrapper.exerciseAllMethods();
109+
110+
assertThat(calledSearchContext.get(), notNullValue());
111+
verify(calledSearchContext.get()).setTask(null);
112+
}
113+
114+
/** Only use to make testing easier */
115+
private static class DefaultingGuardedSearchOperationWrapper implements GuardedSearchOperationWrapper {
116+
117+
@Override
118+
public void onNewReaderContext(ReaderContext readerContext) {
119+
}
120+
121+
@Override
122+
public void onNewScrollContext(ReaderContext readerContext) {
123+
}
124+
125+
@Override
126+
public void onPreQueryPhase(SearchContext context) {
127+
}
128+
129+
@Override
130+
public void onQueryPhase(SearchContext searchContext, long tookInNanos) {
131+
}
132+
133+
@Override
134+
public void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) {
135+
}
136+
137+
void exerciseAllMethods(){
138+
final SearchOperationListener sol = this.toListener();
139+
sol.onNewReaderContext(mock(ReaderContext.class));
140+
sol.onNewScrollContext(mock(ReaderContext.class));
141+
sol.onPreQueryPhase(mock(SearchContext.class));
142+
sol.onQueryPhase(mock(SearchContext.class), 12345L);
143+
sol.validateReaderContext(mock(ReaderContext.class), mock(TransportRequest.class));
144+
}
145+
}
146+
}

0 commit comments

Comments
 (0)