Skip to content

Commit 222dea9

Browse files
committed
Preserve context after remote license check
1 parent 901acdc commit 222dea9

2 files changed

Lines changed: 126 additions & 25 deletions

File tree

x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import org.elasticsearch.ElasticsearchException;
1010
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.action.support.ContextPreservingActionListener;
1112
import org.elasticsearch.client.Client;
1213
import org.elasticsearch.common.util.concurrent.ThreadContext;
1314
import org.elasticsearch.protocol.xpack.XPackInfoRequest;
@@ -187,16 +188,18 @@ public void onFailure(final Exception e) {
187188

188189
private void remoteClusterLicense(final String clusterAlias, final ActionListener<XPackInfoResponse> listener) {
189190
final ThreadContext threadContext = client.threadPool().getThreadContext();
191+
final ContextPreservingActionListener<XPackInfoResponse> contextPreservingActionListener =
192+
new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener);
190193
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
191194
// we stash any context here since this is an internal execution and should not leak any existing context information
192195
threadContext.markAsSystemContext();
193196

194197
final XPackInfoRequest request = new XPackInfoRequest();
195198
request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE));
196199
try {
197-
client.getRemoteClusterClient(clusterAlias).execute(XPackInfoAction.INSTANCE, request, listener);
200+
client.getRemoteClusterClient(clusterAlias).execute(XPackInfoAction.INSTANCE, request, contextPreservingActionListener);
198201
} catch (final Exception e) {
199-
listener.onFailure(e);
202+
contextPreservingActionListener.onFailure(e);
200203
}
201204
}
202205
}

x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java

Lines changed: 121 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
import org.elasticsearch.protocol.xpack.XPackInfoResponse;
1515
import org.elasticsearch.protocol.xpack.license.LicenseStatus;
1616
import org.elasticsearch.test.ESTestCase;
17+
import org.elasticsearch.threadpool.TestThreadPool;
1718
import org.elasticsearch.threadpool.ThreadPool;
1819
import org.elasticsearch.xpack.core.action.XPackInfoAction;
1920

2021
import java.util.ArrayList;
2122
import java.util.Arrays;
2223
import java.util.Collections;
2324
import java.util.List;
25+
import java.util.concurrent.atomic.AtomicBoolean;
2426
import java.util.concurrent.atomic.AtomicInteger;
2527
import java.util.concurrent.atomic.AtomicReference;
2628
import java.util.function.Consumer;
@@ -92,15 +94,15 @@ public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() {
9294
final AtomicInteger index = new AtomicInteger();
9395
final List<XPackInfoResponse> responses = new ArrayList<>();
9496

95-
final Client client = createMockClient();
97+
final ThreadPool threadPool = createMockThreadPool();
98+
final Client client = createMockClient(threadPool);
9699
doAnswer(invocationMock -> {
97100
@SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
98101
(ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
99102
listener.onResponse(responses.get(index.getAndIncrement()));
100103
return null;
101104
}).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
102105

103-
104106
final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
105107
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
106108
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
@@ -110,8 +112,9 @@ public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() {
110112
new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
111113
final AtomicReference<RemoteClusterLicenseChecker.LicenseCheck> licenseCheck = new AtomicReference<>();
112114

113-
licenseChecker.checkRemoteClusterLicenses(remoteClusterAliases,
114-
new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
115+
licenseChecker.checkRemoteClusterLicenses(
116+
remoteClusterAliases,
117+
doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
115118

116119
@Override
117120
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
@@ -123,7 +126,7 @@ public void onFailure(final Exception e) {
123126
fail(e.getMessage());
124127
}
125128

126-
});
129+
}));
127130

128131
verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any());
129132
assertNotNull(licenseCheck.get());
@@ -138,7 +141,8 @@ public void testCheckRemoteClusterLicensesGivenIncompatibleLicense() {
138141
responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null));
139142
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
140143

141-
final Client client = createMockClient();
144+
final ThreadPool threadPool = createMockThreadPool();
145+
final Client client = createMockClient(threadPool);
142146
doAnswer(invocationMock -> {
143147
@SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
144148
(ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
@@ -152,7 +156,7 @@ public void testCheckRemoteClusterLicensesGivenIncompatibleLicense() {
152156

153157
licenseChecker.checkRemoteClusterLicenses(
154158
remoteClusterAliases,
155-
new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
159+
doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
156160

157161
@Override
158162
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
@@ -164,7 +168,7 @@ public void onFailure(final Exception e) {
164168
fail(e.getMessage());
165169
}
166170

167-
});
171+
}));
168172

169173
verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any());
170174
assertNotNull(licenseCheck.get());
@@ -179,15 +183,15 @@ public void testCheckRemoteClusterLicencesGivenNonExistentCluster() {
179183

180184
final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
181185
final String failingClusterAlias = randomFrom(remoteClusterAliases);
182-
final Client client = createMockClientThatThrowsOnGetRemoteClusterClient(failingClusterAlias);
186+
final ThreadPool threadPool = createMockThreadPool();
187+
final Client client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, failingClusterAlias);
183188
doAnswer(invocationMock -> {
184189
@SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
185190
(ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
186191
listener.onResponse(responses.get(index.getAndIncrement()));
187192
return null;
188193
}).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
189194

190-
191195
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
192196
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
193197
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
@@ -196,8 +200,9 @@ public void testCheckRemoteClusterLicencesGivenNonExistentCluster() {
196200
new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
197201
final AtomicReference<Exception> exception = new AtomicReference<>();
198202

199-
licenseChecker.checkRemoteClusterLicenses(remoteClusterAliases,
200-
new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
203+
licenseChecker.checkRemoteClusterLicenses(
204+
remoteClusterAliases,
205+
doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
201206

202207
@Override
203208
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
@@ -209,7 +214,7 @@ public void onFailure(final Exception e) {
209214
exception.set(e);
210215
}
211216

212-
});
217+
}));
213218

214219
assertNotNull(exception.get());
215220
assertThat(exception.get(), instanceOf(ElasticsearchException.class));
@@ -218,6 +223,69 @@ public void onFailure(final Exception e) {
218223
assertThat(exception.get().getCause(), instanceOf(IllegalArgumentException.class));
219224
}
220225

226+
public void testListenerIsExecutedWithCallingContext() throws InterruptedException {
227+
final AtomicInteger index = new AtomicInteger();
228+
final List<XPackInfoResponse> responses = new ArrayList<>();
229+
230+
final ThreadPool threadPool = new TestThreadPool(getTestName());
231+
232+
try {
233+
final List<String> remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3");
234+
final Client client;
235+
final boolean failure = randomBoolean();
236+
if (failure) {
237+
client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, randomFrom(remoteClusterAliases));
238+
} else {
239+
client = createMockClient(threadPool);
240+
}
241+
doAnswer(invocationMock -> {
242+
@SuppressWarnings("unchecked") ActionListener<XPackInfoResponse> listener =
243+
(ActionListener<XPackInfoResponse>) invocationMock.getArguments()[2];
244+
listener.onResponse(responses.get(index.getAndIncrement()));
245+
return null;
246+
}).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
247+
248+
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
249+
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
250+
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
251+
252+
final RemoteClusterLicenseChecker licenseChecker =
253+
new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial);
254+
255+
final AtomicBoolean listenerInvoked = new AtomicBoolean();
256+
threadPool.getThreadContext().putHeader("key", "value");
257+
licenseChecker.checkRemoteClusterLicenses(
258+
remoteClusterAliases,
259+
doubleInvocationProtectingListener(new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
260+
261+
@Override
262+
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
263+
if (failure) {
264+
fail();
265+
}
266+
assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value"));
267+
assertFalse(threadPool.getThreadContext().isSystemContext());
268+
listenerInvoked.set(true);
269+
}
270+
271+
@Override
272+
public void onFailure(final Exception e) {
273+
if (failure == false) {
274+
fail();
275+
}
276+
assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value"));
277+
assertFalse(threadPool.getThreadContext().isSystemContext());
278+
listenerInvoked.set(true);
279+
}
280+
281+
}));
282+
283+
assertTrue(listenerInvoked.get());
284+
} finally {
285+
terminate(threadPool);
286+
}
287+
}
288+
221289
public void testBuildErrorMessageForActiveCompatibleLicense() {
222290
final XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse();
223291
final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info =
@@ -246,22 +314,52 @@ public void testBuildErrorMessageForInactiveLicense() {
246314
equalTo("the license on cluster [expired-cluster] is not active"));
247315
}
248316

249-
private Client createMockClient() {
250-
return createMockClient(client -> when(client.getRemoteClusterClient(anyString())).thenReturn(client));
317+
private ActionListener<RemoteClusterLicenseChecker.LicenseCheck> doubleInvocationProtectingListener(
318+
final ActionListener<RemoteClusterLicenseChecker.LicenseCheck> listener) {
319+
final AtomicBoolean listenerInvoked = new AtomicBoolean();
320+
return new ActionListener<RemoteClusterLicenseChecker.LicenseCheck>() {
321+
322+
@Override
323+
public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) {
324+
if (listenerInvoked.compareAndSet(false, true) == false) {
325+
fail("listener invoked twice");
326+
}
327+
listener.onResponse(response);
328+
}
329+
330+
@Override
331+
public void onFailure(final Exception e) {
332+
if (listenerInvoked.compareAndSet(false, true) == false) {
333+
fail("listener invoked twice");
334+
}
335+
listener.onFailure(e);
336+
}
337+
338+
};
339+
}
340+
341+
private ThreadPool createMockThreadPool() {
342+
final ThreadPool threadPool = mock(ThreadPool.class);
343+
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
344+
return threadPool;
251345
}
252346

253-
private Client createMockClientThatThrowsOnGetRemoteClusterClient(final String clusterAlias) {
254-
return createMockClient(client -> {
255-
when(client.getRemoteClusterClient(clusterAlias)).thenThrow(new IllegalArgumentException());
256-
when(client.getRemoteClusterClient(argThat(not(clusterAlias)))).thenReturn(client);
257-
});
347+
private Client createMockClient(final ThreadPool threadPool) {
348+
return createMockClient(threadPool, client -> when(client.getRemoteClusterClient(anyString())).thenReturn(client));
258349
}
259350

260-
private Client createMockClient(final Consumer<Client> finish) {
351+
private Client createMockClientThatThrowsOnGetRemoteClusterClient(final ThreadPool threadPool, final String clusterAlias) {
352+
return createMockClient(
353+
threadPool,
354+
client -> {
355+
when(client.getRemoteClusterClient(clusterAlias)).thenThrow(new IllegalArgumentException());
356+
when(client.getRemoteClusterClient(argThat(not(clusterAlias)))).thenReturn(client);
357+
});
358+
}
359+
360+
private Client createMockClient(final ThreadPool threadPool, final Consumer<Client> finish) {
261361
final Client client = mock(Client.class);
262-
final ThreadPool threadPool = mock(ThreadPool.class);
263362
when(client.threadPool()).thenReturn(threadPool);
264-
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
265363
finish.accept(client);
266364
return client;
267365
}

0 commit comments

Comments
 (0)