Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,28 @@ public StoredContext stashAndMergeHeaders(Map<String, String> headers) {
return () -> threadLocal.set(context);
}

/**
* Removes the current context and resets a new context that is a copy of the current one except that the request
* headers do not contain the given headers to remove. The removed context can be restored when closing the returned
* {@link StoredContext}.
* @param headersToRemove the request headers to remove
*/
public StoredContext removeRequestHeaders(Set<String> headersToRemove) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd call this stashWithoutHeader()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the other stashXXX() methods start off with the default thread context, I avoided prefixing this one with stash. I don't feel strongly on it though. What do you reckon @DaveCTurner ?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, feel free to merge like it is so that we don't waste time on naming :)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I also don't want to say stash. I didn't come up with an obviously better name, although maybe I'd have called it something like withoutRequestHeaders or removingRequestHeaders rather than using the imperative remove. NBD anyway, naming is hard.

final ThreadContextStruct context = threadLocal.get();
Map<String, String> newRequestHeaders = new HashMap<>(context.requestHeaders);
newRequestHeaders.keySet().removeAll(headersToRemove);
threadLocal.set(
new ThreadContextStruct(
newRequestHeaders,
context.responseHeaders,
context.transientHeaders,
context.isSystemContext,
context.warningHeadersSize
)
);
return () -> threadLocal.set(context);
}

/**
* Just like {@link #stashContext()} but no default context is set.
* @param preserveResponseHeaders if set to <code>true</code> the response headers of the restore thread will be preserved.
Expand Down
6 changes: 6 additions & 0 deletions server/src/main/java/org/elasticsearch/tasks/Task.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@ public class Task implements TracingPlugin.Traceable {
* The request header which is contained in HTTP request. We parse trace.id from it and store it in thread context.
* TRACE_PARENT once parsed in RestController.tryAllHandler is not preserved
* has to be declared as a header copied over from http request.
* May also be used internally when apm plugin is enabled.
*/
public static final String TRACE_PARENT = "traceparent";

/**
* Is used internally to pass the apm trace context between the nodes
*/
public static final String TRACE_STATE = "tracestate";

/**
* Parsed part of traceparent. It is stored in thread context and emitted in logs.
* Has to be declared as a header copied over for tasks.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;

import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -179,6 +180,30 @@ public void testStashAndMerge() {
assertEquals("1", threadContext.getHeader("default"));
}

public void testRemoveHeaders() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
threadContext.putHeader("h_1", "h_1_value");
threadContext.putHeader("h_2", "h_2_value");
threadContext.putHeader("h_3", "h_3_value");

threadContext.putTransient("ctx.transient_1", 1);
threadContext.addResponseHeader("resp.header", "baaaam");
try (ThreadContext.StoredContext ctx = threadContext.removeRequestHeaders(Set.of("h_1", "h_3"))) {
assertThat(threadContext.getHeaders(), equalTo(Map.of("default", "1", "h_2", "h_2_value")));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.transient_1"));
assertEquals("1", threadContext.getHeader("default"));
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
}

assertThat(threadContext.getHeaders(), equalTo(Map.of("default", "1", "h_1", "h_1_value", "h_2", "h_2_value", "h_3", "h_3_value")));
assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.transient_1"));
assertEquals("1", threadContext.getHeader("default"));
assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size());
assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0));
}

public void testStoreContext() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
import org.elasticsearch.tasks.TaskTracer;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.transport.TransportService;
import org.junit.After;

import java.util.Collection;
import java.util.Collections;
import java.util.List;

import static java.util.stream.Collectors.toList;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.notNullValue;

Expand All @@ -44,6 +47,11 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
return Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)).setSecureSettings(secureSettings).build();
}

@After
public void clearRecordedSpans() {
APMTracer.CAPTURING_SPAN_EXPORTER.clear();
}

public void testModule() {
List<TracingPlugin> plugins = internalCluster().getMasterNodeInstance(PluginsService.class).filterPlugins(TracingPlugin.class);
assertThat(plugins, hasSize(1));
Expand All @@ -68,4 +76,23 @@ public void testModule() {
}
assertTrue(found);
}

public void testRecordsNestedSpans() {

APMTracer.CAPTURING_SPAN_EXPORTER.clear();// removing start related events

client().admin().cluster().prepareListTasks().get();

var parentTasks = APMTracer.CAPTURING_SPAN_EXPORTER.findSpanByName("cluster:monitor/tasks/lists").collect(toList());
assertThat(parentTasks, hasSize(1));
var parentTask = parentTasks.get(0);
assertThat(parentTask.getParentSpanId(), equalTo("0000000000000000"));

var childrenTasks = APMTracer.CAPTURING_SPAN_EXPORTER.findSpanByName("cluster:monitor/tasks/lists[n]").collect(toList());
assertThat(childrenTasks, hasSize(internalCluster().size()));
for (SpanData childrenTask : childrenTasks) {
assertThat(childrenTask.getParentSpanId(), equalTo(parentTask.getSpanId()));
assertThat(childrenTask.getTraceId(), equalTo(parentTask.getTraceId()));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,35 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.TracingPlugin;
import org.elasticsearch.repositories.RepositoriesService;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportInterceptor;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.xcontent.NamedXContentRegistry;

import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;

public class APM extends Plugin implements TracingPlugin {
public class APM extends Plugin implements TracingPlugin, NetworkPlugin {

private final SetOnce<Tracer> tracer = new SetOnce<>();
public static final Set<String> TRACE_HEADERS = Set.of(Task.TRACE_PARENT, Task.TRACE_STATE);

private final SetOnce<APMTracer> tracer = new SetOnce<>();
private final Settings settings;

public APM(Settings settings) {
Expand All @@ -51,12 +63,55 @@ public Collection<Object> createComponents(
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
tracer.set(new APMTracer(settings, clusterService));
tracer.set(new APMTracer(settings, threadPool, clusterService));
return List.of(tracer.get());
}

@Override
public List<Setting<?>> getSettings() {
return List.of(APMTracer.APM_ENDPOINT_SETTING, APMTracer.APM_TOKEN_SETTING);
}

public List<TransportInterceptor> getTransportInterceptors(NamedWriteableRegistry namedWriteableRegistry, ThreadContext threadContext) {
return List.of(new TransportInterceptor() {
@Override
public AsyncSender interceptSender(AsyncSender sender) {
return new ApmTransportInterceptor(sender, threadContext);
}
});
}

private class ApmTransportInterceptor implements TransportInterceptor.AsyncSender {

private final TransportInterceptor.AsyncSender sender;
private final ThreadContext threadContext;

ApmTransportInterceptor(TransportInterceptor.AsyncSender sender, ThreadContext threadContext) {
this.sender = sender;
this.threadContext = threadContext;
}

@Override
public <T extends TransportResponse> void sendRequest(
Transport.Connection connection,
String action,
TransportRequest request,
TransportRequestOptions options,
TransportResponseHandler<T> handler
) {
if (tracer.get() == null) {
sender.sendRequest(connection, action, request, options, handler);
} else {
var headers = tracer.get().getSpanHeadersById(String.valueOf(request.getParentTask().getId()));
if (headers != null) {
try (var ignore = threadContext.removeRequestHeaders(TRACE_HEADERS)) {
threadContext.putHeader(headers);
sender.sendRequest(connection, action, request, options, handler);
}
} else {
sender.sendRequest(connection, action, request, options, handler);
}
}
}
}
}
Loading