Skip to content

Commit f134870

Browse files
committed
👕 fixed coderabbitai suggestions
1 parent dd05e19 commit f134870

4 files changed

Lines changed: 75 additions & 2 deletions

File tree

components/core/src/clp/CurlDownloadHandler.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <cstddef>
55
#include <map>
66
#include <memory>
7+
#include <regex>
8+
#include <set>
79
#include <string>
810
#include <string_view>
911
#include <utility>
@@ -58,9 +60,25 @@ CurlDownloadHandler::CurlDownloadHandler(
5860
m_http_headers.append("Cache-Control: no-cache");
5961
m_http_headers.append("Pragma: no-cache");
6062
}
63+
static const std::set<std::string> reserved_headers = {
64+
"range",
65+
"cache-control",
66+
"pragma"
67+
};
6168
for (const auto& [key, value] : custom_headers) {
62-
if ("Range" != key && "Cache-Control" != key && "Pragma" != key) {
63-
m_http_headers.append(key + ": " + value);
69+
// Convert to lowercase for case-insensitive comparison
70+
std::string lower_key = key;
71+
std::transform(lower_key.begin(), lower_key.end(), lower_key.begin(), ::tolower);
72+
73+
if (reserved_headers.end() == reserved_headers.find(lower_key)) {
74+
// Filter out illegal header names and header values by regex
75+
// Can contain alphanumeric characters (A-Z, a-z, 0-9), hyphens (`-`), and underscores (`_`)
76+
std::regex header_name_pattern("^[A-Za-z0-9_-]+$");
77+
// Must consist of printable ASCII characters (values between 0x20 and 0x7E)
78+
std::regex header_value_pattern("^[\\x20-\\x7E]*$");
79+
if (std::regex_match(key, header_name_pattern) && std::regex_match(value, header_value_pattern)) {
80+
m_http_headers.append(key + ": " + value);
81+
}
6482
}
6583
}
6684
if (false == m_http_headers.is_empty()) {

components/core/src/clp/CurlDownloadHandler.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class CurlDownloadHandler {
5454
* Doc: https://curl.se/libcurl/c/CURLOPT_CONNECTTIMEOUT.html
5555
* @param overall_timeout Maximum time that the transfer may take. Note that this includes
5656
* `connection_timeout`. Doc: https://curl.se/libcurl/c/CURLOPT_TIMEOUT.html
57+
* @param custom_headers Custom request headers passed by users.
5758
*/
5859
explicit CurlDownloadHandler(
5960
std::shared_ptr<ErrorMsgBuf> error_msg_buf,

components/core/src/clp/NetworkReader.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class NetworkReader : public ReaderInterface {
9595
* Doc: https://curl.se/libcurl/c/CURLOPT_CONNECTTIMEOUT.html
9696
* @param buffer_pool_size The required number of buffers in the buffer pool.
9797
* @param buffer_size The size of each buffer in the buffer pool.
98+
* @param custom_headers Custom request headers passed by users.
9899
*/
99100
explicit NetworkReader(
100101
std::string_view src_url,
@@ -244,6 +245,7 @@ class NetworkReader : public ReaderInterface {
244245
* @param reader
245246
* @param offset Index of the byte at which to start the download.
246247
* @param disable_caching Whether to disable caching.
248+
* @param custom_headers Custom request headers passed by users.
247249
*/
248250
DownloaderThread(NetworkReader& reader,
249251
size_t offset,

components/core/tests/test-NetworkReader.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <cstddef>
33
#include <cstdint>
44
#include <filesystem>
5+
#include <map>
56
#include <memory>
67
#include <string>
78
#include <string_view>
@@ -188,3 +189,54 @@ TEST_CASE("network_reader_illegal_offset", "[NetworkReader]") {
188189
size_t pos{};
189190
REQUIRE((clp::ErrorCode_Failure == reader.try_get_pos(pos)));
190191
}
192+
193+
// Exercises NetworkReader with user-supplied request headers: well-formed custom headers are
// expected to show up in the server's echo of the request, while reserved headers and malformed
// names/values are expected to be absent. This test issues a real request to httpbin.org.
TEST_CASE("network_reader_with_custom_headers", "[NetworkReader]") {
194+
std::map<std::string, std::string> custom_headers = std::map<std::string, std::string>();
195+
// We use httpbin (https://httpbin.org/) to test the custom headers. This request will return a
196+
// JSON object that contains the custom headers. We check if the headers are in the response.
197+
const int NR_HEADERS = 10;
198+
// Populate NR_HEADERS well-formed headers: Unit-Test-Key<i> -> Unit-Test-Value<i>.
for (int i = 0; i < NR_HEADERS; i++) {
199+
std::string key = "Unit-Test-Key" + std::to_string(i);
200+
std::string value = "Unit-Test-Value" + std::to_string(i);
201+
custom_headers[key] = value;
202+
}
203+
// The following three headers are determined by offset and disable_cache, which should not be
204+
// overridden by custom headers.
205+
custom_headers["Range"] = "bytes=100-";
206+
custom_headers["Cache-Control"] = "no-cache";
207+
custom_headers["Pragma"] = "no-cache";
208+
// Some illegal custom header names and values, which should not be added into headers
209+
custom_headers["A Space"] = "xx";
210+
custom_headers["A\nNewline"] = "xx";
211+
custom_headers["An@At"] = "xx";
212+
// NOTE(review): a name pattern of the form `^[A-Za-z0-9_-]+$` accepts a leading '-', so this
// name would presumably pass such a filter -- confirm against the handler's actual pattern.
custom_headers["-Start-with-Non-Alphanumeric"] = "xx";
213+
custom_headers["Legal-Name1"] = "newline\n";
214+
custom_headers["Legal-Name2"] = "control-char\x01";
215+
// NOTE(review): the positional arguments after the URL are presumably offset (0) and
// disable_caching (false) -- confirm against the NetworkReader constructor.
clp::NetworkReader reader{
216+
"https://httpbin.org/headers",
217+
0,
218+
false,
219+
clp::CurlDownloadHandler::cDefaultOverallTimeout,
220+
clp::CurlDownloadHandler::cDefaultConnectionTimeout,
221+
clp::NetworkReader::cDefaultBufferPoolSize,
222+
clp::NetworkReader::cDefaultBufferSize,
223+
custom_headers
224+
};
225+
auto const actual{get_content(reader)};
226+
// Copy the downloaded content into a std::string so it can be searched with find().
std::string actual_string(actual.begin(), actual.end());
227+
REQUIRE(assert_curl_error_code(CURLE_OK, reader));
228+
// Every well-formed custom header must appear in the response as a `"key": "value"` pair.
for (int i = 0; i < NR_HEADERS; i++) {
229+
std::string field = "\"Unit-Test-Key" + std::to_string(i) + "\": \"Unit-Test-Value"
230+
+ std::to_string(i) + "\"";
231+
REQUIRE(std::string::npos != actual_string.find(field));
232+
}
233+
// NOTE(review): the checks below search for raw `Name: value` / `Name:` text, whereas the
// positive checks above match quoted `"Name": "value"` pairs. If the response is JSON as
// described at the top of this test, these substrings likely never occur, so the assertions may
// pass even when a header was actually sent -- consider matching the quoted-key form instead.
REQUIRE(std::string::npos == actual_string.find("Range: bytes=100-"));
234+
REQUIRE(std::string::npos == actual_string.find("Cache-Control: no-cache"));
235+
REQUIRE(std::string::npos == actual_string.find("Pragma: no-cache"));
236+
REQUIRE(std::string::npos == actual_string.find("A Space:"));
237+
REQUIRE(std::string::npos == actual_string.find("A\nNewline:"));
238+
REQUIRE(std::string::npos == actual_string.find("An@At:"));
239+
REQUIRE(std::string::npos == actual_string.find("-Start-with-Non-Alphanumeric:"));
240+
REQUIRE(std::string::npos == actual_string.find("Legal-Name1:"));
241+
REQUIRE(std::string::npos == actual_string.find("Legal-Name2:"));
242+
}

0 commit comments

Comments
 (0)