Skip to content

Commit 5f2f7f8

Browse files
committed
Pass full RestResponse to user from Extension
Signed-off-by: Daniel Widdis <widdis@gmail.com>
1 parent 50f2d9e commit 5f2f7f8

4 files changed

Lines changed: 241 additions & 22 deletions

File tree

server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,104 @@
1010

1111
import org.opensearch.common.io.stream.StreamInput;
1212
import org.opensearch.common.io.stream.StreamOutput;
13+
import org.opensearch.rest.BytesRestResponse;
14+
import org.opensearch.rest.RestResponse;
15+
import org.opensearch.rest.RestStatus;
1316
import org.opensearch.transport.TransportResponse;
1417

1518
import java.io.IOException;
19+
import java.nio.charset.StandardCharsets;
20+
import java.util.Collections;
21+
import java.util.List;
22+
import java.util.Map;
1623

1724
/**
18-
* Response to execute REST Actions on the extension node.
25+
* Response to execute REST Actions on the extension node. Wraps the components of a {@link RestResponse}.
1926
*
2027
* @opensearch.internal
2128
*/
2229
public class RestExecuteOnExtensionResponse extends TransportResponse {
23-
private String response;
2430

25-
public RestExecuteOnExtensionResponse(String response) {
26-
this.response = response;
31+
private RestStatus status;
32+
private String contentType;
33+
private byte[] content;
34+
private Map<String, List<String>> headers;
35+
36+
/**
37+
* Instantiate this object with a status and response string.
38+
*
39+
* @param status The REST status.
40+
* @param responseString The response content as a String.
41+
*/
42+
public RestExecuteOnExtensionResponse(RestStatus status, String responseString) {
43+
this(status, BytesRestResponse.TEXT_CONTENT_TYPE, responseString.getBytes(StandardCharsets.UTF_8), Collections.emptyMap());
44+
}
45+
46+
/**
47+
* Instantiate this object with the components of a {@link RestResponse}.
48+
*
49+
* @param status The REST status.
50+
* @param contentType The type of the content.
51+
* @param content The content.
52+
* @param headers The headers.
53+
*/
54+
public RestExecuteOnExtensionResponse(RestStatus status, String contentType, byte[] content, Map<String, List<String>> headers) {
55+
setStatus(status);
56+
setContentType(contentType);
57+
setContent(content);
58+
setHeaders(headers);
2759
}
2860

61+
/**
62+
* Instantiate this object from a Transport Stream
63+
*
64+
* @param in The stream input.
65+
* @throws IOException on transport failure.
66+
*/
2967
public RestExecuteOnExtensionResponse(StreamInput in) throws IOException {
30-
response = in.readString();
68+
setStatus(RestStatus.readFrom(in));
69+
setContentType(in.readString());
70+
setContent(in.readByteArray());
71+
setHeaders(in.readMapOfLists(StreamInput::readString, StreamInput::readString));
3172
}
3273

3374
@Override
3475
public void writeTo(StreamOutput out) throws IOException {
35-
out.writeString(response);
76+
RestStatus.writeTo(out, status);
77+
out.writeString(contentType);
78+
out.writeByteArray(content);
79+
out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString);
80+
}
81+
82+
public RestStatus getStatus() {
83+
return status;
84+
}
85+
86+
public void setStatus(RestStatus status) {
87+
this.status = status;
88+
}
89+
90+
public String getContentType() {
91+
return contentType;
92+
}
93+
94+
public void setContentType(String contentType) {
95+
this.contentType = contentType;
96+
}
97+
98+
public byte[] getContent() {
99+
return content;
100+
}
101+
102+
public void setContent(byte[] content) {
103+
this.content = content;
104+
}
105+
106+
public Map<String, List<String>> getHeaders() {
107+
return headers;
36108
}
37109

38-
public String getResponse() {
39-
return response;
110+
public void setHeaders(Map<String, List<String>> headers) {
111+
this.headers = Map.copyOf(headers);
40112
}
41113
}

server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import org.apache.logging.log4j.LogManager;
1212
import org.apache.logging.log4j.Logger;
13-
import org.apache.logging.log4j.message.ParameterizedMessage;
1413
import org.opensearch.client.node.NodeClient;
1514
import org.opensearch.common.io.stream.StreamInput;
1615
import org.opensearch.extensions.DiscoveryExtension;
@@ -26,11 +25,14 @@
2625
import org.opensearch.transport.TransportService;
2726

2827
import java.io.IOException;
28+
import java.nio.charset.StandardCharsets;
2929
import java.util.ArrayList;
3030
import java.util.List;
31+
import java.util.Map.Entry;
3132
import java.util.concurrent.CountDownLatch;
3233
import java.util.concurrent.TimeUnit;
3334

35+
import static java.util.Collections.emptyMap;
3436
import static java.util.Collections.unmodifiableList;
3537

3638
/**
@@ -97,8 +99,13 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
9799
}
98100
String message = "Forwarding the request " + method + " " + uri + " to " + discoveryExtension;
99101
logger.info(message);
100-
// Hack to pass a final class in to fetch the response string
101-
final StringBuilder responseBuilder = new StringBuilder();
102+
// Initialize response. Values will be changed in the handler.
103+
final RestExecuteOnExtensionResponse restExecuteOnExtensionResponse = new RestExecuteOnExtensionResponse(
104+
RestStatus.ACCEPTED,
105+
BytesRestResponse.TEXT_CONTENT_TYPE,
106+
message.getBytes(StandardCharsets.UTF_8),
107+
emptyMap()
108+
);
102109
final CountDownLatch inProgressLatch = new CountDownLatch(1);
103110
final TransportResponseHandler<RestExecuteOnExtensionResponse> restExecuteOnExtensionResponseHandler = new TransportResponseHandler<
104111
RestExecuteOnExtensionResponse>() {
@@ -110,15 +117,20 @@ public RestExecuteOnExtensionResponse read(StreamInput in) throws IOException {
110117

111118
@Override
112119
public void handleResponse(RestExecuteOnExtensionResponse response) {
113-
responseBuilder.append(response.getResponse());
114-
logger.info("Received response from extension: {}", response.getResponse());
120+
logger.info("Received response from extension: {}", response.getStatus());
121+
restExecuteOnExtensionResponse.setStatus(response.getStatus());
122+
restExecuteOnExtensionResponse.setContentType(response.getContentType());
123+
restExecuteOnExtensionResponse.setContent(response.getContent());
124+
restExecuteOnExtensionResponse.setHeaders(response.getHeaders());
115125
inProgressLatch.countDown();
116126
}
117127

118128
@Override
119129
public void handleException(TransportException exp) {
120-
responseBuilder.append("FAILED: ").append(exp);
121-
logger.debug(new ParameterizedMessage("REST request failed"), exp);
130+
logger.debug("REST request failed", exp);
131+
restExecuteOnExtensionResponse.setStatus(RestStatus.INTERNAL_SERVER_ERROR);
132+
byte[] responseBytes = ("Request failed: " + exp.getMessage()).getBytes(StandardCharsets.UTF_8);
133+
restExecuteOnExtensionResponse.setContent(responseBytes);
122134
inProgressLatch.countDown();
123135
}
124136

@@ -144,12 +156,18 @@ public String executor() {
144156
} catch (Exception e) {
145157
logger.info("Failed to send REST Actions to extension " + discoveryExtension.getName(), e);
146158
}
147-
String response = responseBuilder.toString();
148-
if (response.isBlank() || response.startsWith("FAILED")) {
149-
return channel -> channel.sendResponse(
150-
new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, response.isBlank() ? "Request Failed" : response)
151-
);
159+
160+
BytesRestResponse restResponse = new BytesRestResponse(
161+
restExecuteOnExtensionResponse.getStatus(),
162+
restExecuteOnExtensionResponse.getContentType(),
163+
restExecuteOnExtensionResponse.getContent()
164+
);
165+
for (Entry<String, List<String>> headerEntry : restExecuteOnExtensionResponse.getHeaders().entrySet()) {
166+
for (String value : headerEntry.getValue()) {
167+
restResponse.addHeader(headerEntry.getKey(), value);
168+
}
152169
}
153-
return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.OK, response));
170+
171+
return channel -> channel.sendResponse(restResponse);
154172
}
155173
}

server/src/test/java/org/opensearch/extensions/rest/RegisterRestActionsTests.java

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
package org.opensearch.extensions.rest;
1010

1111
import java.util.List;
12+
13+
import org.opensearch.common.bytes.BytesReference;
14+
import org.opensearch.common.io.stream.BytesStreamInput;
15+
import org.opensearch.common.io.stream.BytesStreamOutput;
1216
import org.opensearch.test.OpenSearchTestCase;
1317

1418
public class RegisterRestActionsTests extends OpenSearchTestCase {
@@ -17,11 +21,42 @@ public void testRegisterRestActionsRequest() throws Exception {
1721
String uniqueIdStr = "uniqueid1";
1822
List<String> expected = List.of("GET /foo", "PUT /bar", "POST /baz");
1923
RegisterRestActionsRequest registerRestActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, expected);
20-
assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId());
2124

25+
assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId());
2226
List<String> restActions = registerRestActionsRequest.getRestActions();
2327
assertEquals(expected.size(), restActions.size());
2428
assertTrue(restActions.containsAll(expected));
2529
assertTrue(expected.containsAll(restActions));
30+
31+
try (BytesStreamOutput out = new BytesStreamOutput()) {
32+
registerRestActionsRequest.writeTo(out);
33+
out.flush();
34+
try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) {
35+
registerRestActionsRequest = new RegisterRestActionsRequest(in);
36+
37+
assertEquals(uniqueIdStr, registerRestActionsRequest.getUniqueId());
38+
restActions = registerRestActionsRequest.getRestActions();
39+
assertEquals(expected.size(), restActions.size());
40+
assertTrue(restActions.containsAll(expected));
41+
assertTrue(expected.containsAll(restActions));
42+
}
43+
}
44+
}
45+
46+
public void testRegisterRestActionsResponse() throws Exception {
47+
String response = "This is a response";
48+
RegisterRestActionsResponse registerRestActionsResponse = new RegisterRestActionsResponse(response);
49+
50+
assertEquals(response, registerRestActionsResponse.getResponse());
51+
52+
try (BytesStreamOutput out = new BytesStreamOutput()) {
53+
registerRestActionsResponse.writeTo(out);
54+
out.flush();
55+
try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) {
56+
registerRestActionsResponse = new RegisterRestActionsResponse(in);
57+
58+
assertEquals(response, registerRestActionsResponse.getResponse());
59+
}
60+
}
2661
}
2762
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.extensions.rest;
10+
11+
import org.opensearch.rest.RestStatus;
12+
import org.opensearch.common.bytes.BytesReference;
13+
import org.opensearch.common.io.stream.BytesStreamInput;
14+
import org.opensearch.common.io.stream.BytesStreamOutput;
15+
import org.opensearch.rest.BytesRestResponse;
16+
import org.opensearch.rest.RestRequest.Method;
17+
import org.opensearch.test.OpenSearchTestCase;
18+
19+
import java.nio.charset.StandardCharsets;
20+
import java.util.List;
21+
import java.util.Map;
22+
23+
public class RestExecuteOnExtensionTests extends OpenSearchTestCase {
24+
25+
public void testRestExecuteOnExtensionRequest() throws Exception {
26+
Method expectedMethod = Method.GET;
27+
String expectedUri = "/test/uri";
28+
RestExecuteOnExtensionRequest request = new RestExecuteOnExtensionRequest(expectedMethod, expectedUri);
29+
30+
assertEquals(expectedMethod, request.getMethod());
31+
assertEquals(expectedUri, request.getUri());
32+
33+
try (BytesStreamOutput out = new BytesStreamOutput()) {
34+
request.writeTo(out);
35+
out.flush();
36+
try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) {
37+
request = new RestExecuteOnExtensionRequest(in);
38+
39+
assertEquals(expectedMethod, request.getMethod());
40+
assertEquals(expectedUri, request.getUri());
41+
}
42+
}
43+
}
44+
45+
public void testRestExecuteOnExtensionResponse() throws Exception {
46+
RestStatus expectedStatus = RestStatus.OK;
47+
String expectedContentType = BytesRestResponse.TEXT_CONTENT_TYPE;
48+
String expectedResponse = "Test response";
49+
byte[] expectedResponseBytes = expectedResponse.getBytes(StandardCharsets.UTF_8);
50+
51+
RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse(expectedStatus, expectedResponse);
52+
53+
assertEquals(expectedStatus, response.getStatus());
54+
assertEquals(expectedContentType, response.getContentType());
55+
assertArrayEquals(expectedResponseBytes, response.getContent());
56+
assertEquals(0, response.getHeaders().size());
57+
58+
String headerKey = "foo";
59+
List<String> headerValueList = List.of("bar", "baz");
60+
Map<String, List<String>> expectedHeaders = Map.of(headerKey, headerValueList);
61+
62+
response = new RestExecuteOnExtensionResponse(expectedStatus, expectedContentType, expectedResponseBytes, expectedHeaders);
63+
64+
assertEquals(expectedStatus, response.getStatus());
65+
assertEquals(expectedContentType, response.getContentType());
66+
assertArrayEquals(expectedResponseBytes, response.getContent());
67+
68+
assertEquals(1, expectedHeaders.keySet().size());
69+
assertTrue(expectedHeaders.containsKey(headerKey));
70+
71+
List<String> fooList = expectedHeaders.get(headerKey);
72+
assertEquals(2, fooList.size());
73+
assertTrue(fooList.containsAll(headerValueList));
74+
75+
try (BytesStreamOutput out = new BytesStreamOutput()) {
76+
response.writeTo(out);
77+
out.flush();
78+
try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) {
79+
response = new RestExecuteOnExtensionResponse(in);
80+
81+
assertEquals(expectedStatus, response.getStatus());
82+
assertEquals(expectedContentType, response.getContentType());
83+
assertArrayEquals(expectedResponseBytes, response.getContent());
84+
85+
assertEquals(1, expectedHeaders.keySet().size());
86+
assertTrue(expectedHeaders.containsKey(headerKey));
87+
88+
fooList = expectedHeaders.get(headerKey);
89+
assertEquals(2, fooList.size());
90+
assertTrue(fooList.containsAll(headerValueList));
91+
}
92+
}
93+
}
94+
}

0 commit comments

Comments
 (0)