Skip to content

Commit 54e551e

Browse files
committed
feat(RequestTracing): add capture streams and middleware for enhanced request/response tracing
1 parent 30b7917 commit 54e551e

10 files changed

Lines changed: 613 additions & 157 deletions

File tree

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using Chats.BE.Services.RequestTracing;
2+
using System.Text;
3+
4+
namespace Chats.BE.UnitTest.Services.RequestTracing;
5+
6+
public class CaptureStreamsTests
7+
{
8+
[Fact]
9+
public void WriteCaptureStream_ShouldCapturePrefixAndMarkTruncated()
10+
{
11+
byte[] payload = Encoding.UTF8.GetBytes("abcdef");
12+
using MemoryStream inner = new();
13+
using WriteCaptureStream tee = new(inner, 4);
14+
15+
tee.Write(payload, 0, payload.Length);
16+
17+
Assert.Equal(payload.Length, inner.ToArray().Length);
18+
Assert.Equal(Encoding.UTF8.GetBytes("abcd"), tee.CapturedBytes);
19+
Assert.True(tee.IsTruncated);
20+
}
21+
22+
[Fact]
23+
public void WriteCaptureStream_ShouldUseDefaultLimitWhenNotProvided()
24+
{
25+
byte[] payload = Encoding.UTF8.GetBytes("hello");
26+
using MemoryStream inner = new();
27+
using WriteCaptureStream tee = new(inner);
28+
29+
tee.Write(payload, 0, payload.Length);
30+
31+
Assert.Equal(payload, tee.CapturedBytes);
32+
Assert.False(tee.IsTruncated);
33+
}
34+
35+
[Fact]
36+
public void ReadCaptureStream_ShouldCaptureReadPrefixAndMarkTruncated()
37+
{
38+
byte[] payload = Encoding.UTF8.GetBytes("abcdef");
39+
using MemoryStream inner = new(payload);
40+
41+
int totalBytesRead = 0;
42+
byte[]? capturedBytes = null;
43+
bool truncated = false;
44+
45+
using ReadCaptureStream capture = new(
46+
inner,
47+
4,
48+
(total, captured, isTruncated) =>
49+
{
50+
totalBytesRead = total;
51+
capturedBytes = captured;
52+
truncated = isTruncated;
53+
});
54+
55+
byte[] buffer = new byte[32];
56+
while (capture.Read(buffer, 0, buffer.Length) > 0)
57+
{
58+
}
59+
60+
Assert.Equal(payload.Length, totalBytesRead);
61+
Assert.Equal(Encoding.UTF8.GetBytes("abcd"), capturedBytes);
62+
Assert.True(truncated);
63+
}
64+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
using Chats.BE.Services.RequestTracing;
2+
using System.Text;
3+
4+
namespace Chats.BE.UnitTest.Services.RequestTracing;
5+
6+
public class ObservedHttpContentTests
7+
{
8+
[Fact]
9+
public async Task ObservedRequestHttpContent_ShouldInvokeCallback_WhenSerialized()
10+
{
11+
byte[] payload = Encoding.UTF8.GetBytes("abcdef");
12+
using ByteArrayContent inner = new(payload);
13+
14+
int totalBytes = -1;
15+
byte[]? capturedBytes = null;
16+
bool truncated = false;
17+
18+
using ObservedRequestHttpContent observed = new(
19+
inner,
20+
4,
21+
(total, captured, isTruncated) =>
22+
{
23+
totalBytes = total;
24+
capturedBytes = captured;
25+
truncated = isTruncated;
26+
});
27+
28+
using MemoryStream sink = new();
29+
await observed.CopyToAsync(sink);
30+
31+
Assert.Equal(payload.Length, totalBytes);
32+
Assert.Equal(Encoding.UTF8.GetBytes("abcd"), capturedBytes);
33+
Assert.True(truncated);
34+
}
35+
36+
[Fact]
37+
public async Task ObservedResponseHttpContent_ShouldInvokeCallback_WhenReadToEnd()
38+
{
39+
byte[] payload = Encoding.UTF8.GetBytes("abcdef");
40+
using ByteArrayContent inner = new(payload);
41+
42+
int totalBytes = -1;
43+
byte[]? capturedBytes = null;
44+
bool truncated = false;
45+
46+
using ObservedResponseHttpContent observed = new(
47+
inner,
48+
4,
49+
(total, captured, isTruncated) =>
50+
{
51+
totalBytes = total;
52+
capturedBytes = captured;
53+
truncated = isTruncated;
54+
});
55+
56+
await using Stream stream = await observed.ReadAsStreamAsync();
57+
byte[] buffer = new byte[16];
58+
while (await stream.ReadAsync(buffer, 0, buffer.Length) > 0)
59+
{
60+
}
61+
62+
Assert.Equal(payload.Length, totalBytes);
63+
Assert.Equal(Encoding.UTF8.GetBytes("abcd"), capturedBytes);
64+
Assert.True(truncated);
65+
}
66+
67+
[Fact]
68+
public void ObservedResponseHttpContent_ShouldNotInvokeCallback_WhenNotConsumed()
69+
{
70+
byte[] payload = Encoding.UTF8.GetBytes("abcdef");
71+
using ByteArrayContent inner = new(payload);
72+
73+
bool invoked = false;
74+
using ObservedResponseHttpContent observed = new(
75+
inner,
76+
4,
77+
(_, _, _) => invoked = true);
78+
79+
Assert.False(invoked);
80+
}
81+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using Chats.BE.Services.RequestTracing;
2+
3+
namespace Chats.BE.UnitTest.Services.RequestTracing;
4+
5+
public class RequestTraceHelperTests
6+
{
7+
[Fact]
8+
public void ResolveRawCaptureLimit_ShouldUseDefaultWhenMissingOrInvalid()
9+
{
10+
Assert.Equal(RequestTraceHelper.DefaultRawCaptureMaxBytes, RequestTraceHelper.ResolveRawCaptureLimit(null));
11+
Assert.Equal(RequestTraceHelper.DefaultRawCaptureMaxBytes, RequestTraceHelper.ResolveRawCaptureLimit(0));
12+
Assert.Equal(RequestTraceHelper.DefaultRawCaptureMaxBytes, RequestTraceHelper.ResolveRawCaptureLimit(-1));
13+
}
14+
15+
[Fact]
16+
public void ResolveRawCaptureLimit_ShouldUseConfiguredWhenPositive()
17+
{
18+
Assert.Equal(12345, RequestTraceHelper.ResolveRawCaptureLimit(12345));
19+
}
20+
21+
[Fact]
22+
public void IsSmallKnownLength_ShouldRespectKnownLengthAndFloorCap()
23+
{
24+
Assert.False(RequestTraceHelper.IsSmallKnownLength(null, 1024));
25+
Assert.False(RequestTraceHelper.IsSmallKnownLength(-1, 1024));
26+
27+
Assert.True(RequestTraceHelper.IsSmallKnownLength(200 * 1024, 100));
28+
Assert.False(RequestTraceHelper.IsSmallKnownLength(300 * 1024, 100));
29+
30+
Assert.True(RequestTraceHelper.IsSmallKnownLength(2 * 1024 * 1024, 3 * 1024 * 1024));
31+
Assert.False(RequestTraceHelper.IsSmallKnownLength(4 * 1024 * 1024, 3 * 1024 * 1024));
32+
}
33+
}

src/BE/web/Infrastructure/RequestTraceMiddleware.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,11 @@ public async Task Invoke(HttpContext context)
6161
bool captureRequestBody = config.Body.CaptureRequestBody || config.Body.CaptureRawRequestBody;
6262
if (captureRequestBody && context.Request.Body != Stream.Null && context.Request.Body.CanRead)
6363
{
64+
int rawCaptureLimit = RequestTraceHelper.ResolveRawCaptureLimit(null);
6465
Stream originalRequestBody = context.Request.Body;
65-
context.Request.Body = new RequestReadCaptureStream(
66+
context.Request.Body = new ReadCaptureStream(
6667
originalRequestBody,
67-
config.Body.MaxTextCharsForTruncate,
68+
rawCaptureLimit,
6869
(totalBytesRead, capturedBytes, truncated) =>
6970
{
7071
try
@@ -108,10 +109,10 @@ public async Task Invoke(HttpContext context)
108109
}
109110

110111
Stream originalResponseBody = context.Response.Body;
111-
TeeCaptureStream? tee = null;
112+
WriteCaptureStream? tee = null;
112113
if (config.Body.CaptureResponseBody || config.Body.CaptureRawResponseBody)
113114
{
114-
tee = new TeeCaptureStream(originalResponseBody, config.Body.MaxTextCharsForTruncate);
115+
tee = new WriteCaptureStream(originalResponseBody);
115116
context.Response.Body = tee;
116117
}
117118

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
using System.Net;
2+
using System.Net.Http.Headers;
3+
4+
namespace Chats.BE.Services.RequestTracing;
5+
6+
internal sealed class ObservedRequestHttpContent : HttpContent
7+
{
8+
private readonly HttpContent _inner;
9+
private readonly int _maxCaptureBytes;
10+
private readonly Action<int, byte[], bool> _onCompleted;
11+
private int _completedFlag;
12+
13+
public ObservedRequestHttpContent(HttpContent inner, int maxCaptureBytes, Action<int, byte[], bool> onCompleted)
14+
{
15+
_inner = inner;
16+
_maxCaptureBytes = maxCaptureBytes;
17+
_onCompleted = onCompleted;
18+
CopyHeaders(_inner.Headers, Headers);
19+
}
20+
21+
protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context)
22+
{
23+
await SerializeToStreamAsync(stream, context, CancellationToken.None);
24+
}
25+
26+
protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context, CancellationToken cancellationToken)
27+
{
28+
using WriteCaptureStream captureStream = new(stream, _maxCaptureBytes);
29+
try
30+
{
31+
await _inner.CopyToAsync(captureStream, context, cancellationToken);
32+
CompleteOnce(captureStream.TotalBytesWritten, captureStream.CapturedBytes, captureStream.IsTruncated);
33+
}
34+
catch
35+
{
36+
CompleteOnce(captureStream.TotalBytesWritten, captureStream.CapturedBytes, true);
37+
throw;
38+
}
39+
}
40+
41+
protected override bool TryComputeLength(out long length)
42+
{
43+
if (_inner.Headers.ContentLength.HasValue)
44+
{
45+
length = _inner.Headers.ContentLength.Value;
46+
return true;
47+
}
48+
49+
length = 0;
50+
return false;
51+
}
52+
53+
protected override void Dispose(bool disposing)
54+
{
55+
if (disposing)
56+
{
57+
_inner.Dispose();
58+
}
59+
60+
base.Dispose(disposing);
61+
}
62+
63+
private void CompleteOnce(int totalBytes, byte[] capturedBytes, bool truncated)
64+
{
65+
if (Interlocked.Exchange(ref _completedFlag, 1) != 0)
66+
{
67+
return;
68+
}
69+
70+
_onCompleted(totalBytes, capturedBytes, truncated);
71+
}
72+
73+
protected override Task<Stream> CreateContentReadStreamAsync()
74+
{
75+
return _inner.ReadAsStreamAsync();
76+
}
77+
78+
protected override Stream CreateContentReadStream(CancellationToken cancellationToken)
79+
{
80+
return _inner.ReadAsStream(cancellationToken);
81+
}
82+
83+
private static void CopyHeaders(HttpContentHeaders source, HttpContentHeaders target)
84+
{
85+
foreach (KeyValuePair<string, IEnumerable<string>> header in source)
86+
{
87+
target.TryAddWithoutValidation(header.Key, header.Value);
88+
}
89+
}
90+
}
91+
92+
internal sealed class ObservedResponseHttpContent : HttpContent
93+
{
94+
private readonly HttpContent _inner;
95+
private readonly int _maxCaptureBytes;
96+
private readonly Action<int, byte[], bool> _onCompleted;
97+
private int _completedFlag;
98+
99+
public ObservedResponseHttpContent(HttpContent inner, int maxCaptureBytes, Action<int, byte[], bool> onCompleted)
100+
{
101+
_inner = inner;
102+
_maxCaptureBytes = maxCaptureBytes;
103+
_onCompleted = onCompleted;
104+
CopyHeaders(_inner.Headers, Headers);
105+
}
106+
107+
protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context)
108+
{
109+
await SerializeToStreamAsync(stream, context, CancellationToken.None);
110+
}
111+
112+
protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context, CancellationToken cancellationToken)
113+
{
114+
await using Stream source = await _inner.ReadAsStreamAsync(cancellationToken);
115+
using ReadCaptureStream captureStream = new(
116+
source,
117+
_maxCaptureBytes,
118+
CompleteOnce);
119+
120+
byte[] buffer = new byte[81920];
121+
while (true)
122+
{
123+
int read = await captureStream.ReadAsync(buffer.AsMemory(0, buffer.Length), cancellationToken);
124+
if (read <= 0)
125+
{
126+
break;
127+
}
128+
129+
await stream.WriteAsync(buffer.AsMemory(0, read), cancellationToken);
130+
}
131+
}
132+
133+
protected override bool TryComputeLength(out long length)
134+
{
135+
if (_inner.Headers.ContentLength.HasValue)
136+
{
137+
length = _inner.Headers.ContentLength.Value;
138+
return true;
139+
}
140+
141+
length = 0;
142+
return false;
143+
}
144+
145+
protected override void Dispose(bool disposing)
146+
{
147+
if (disposing)
148+
{
149+
_inner.Dispose();
150+
}
151+
152+
base.Dispose(disposing);
153+
}
154+
155+
protected override async Task<Stream> CreateContentReadStreamAsync()
156+
{
157+
Stream source = await _inner.ReadAsStreamAsync();
158+
return new ReadCaptureStream(source, _maxCaptureBytes, CompleteOnce);
159+
}
160+
161+
protected override Stream CreateContentReadStream(CancellationToken cancellationToken)
162+
{
163+
Stream source = _inner.ReadAsStream(cancellationToken);
164+
return new ReadCaptureStream(source, _maxCaptureBytes, CompleteOnce);
165+
}
166+
167+
private void CompleteOnce(int totalBytes, byte[] capturedBytes, bool truncated)
168+
{
169+
if (Interlocked.Exchange(ref _completedFlag, 1) != 0)
170+
{
171+
return;
172+
}
173+
174+
_onCompleted(totalBytes, capturedBytes, truncated);
175+
}
176+
177+
178+
private static void CopyHeaders(HttpContentHeaders source, HttpContentHeaders target)
179+
{
180+
foreach (KeyValuePair<string, IEnumerable<string>> header in source)
181+
{
182+
target.TryAddWithoutValidation(header.Key, header.Value);
183+
}
184+
}
185+
}

0 commit comments

Comments
 (0)