Skip to content

Commit d3c1220

Browse files
[ACTP] Fix memory DoS via unbounded script output buffering (#46920)
## Summary Enforce the 10MB script output limit during execution, not after, preventing memory DoS from unbounded buffering. ## Description cmd.Stdout and cmd.Stderr used raw bytes.Buffer, so a script could write gigabytes into process memory before cmd.Run() returned and the size check ran. The fix introduces a limitedWriter (shared counter across stdout/stderr) that returns an error mid-write once the limit is hit, breaking the pipe to the child process. The experimental shell script runner had no size limit at all. ## Changes - limited_writer.go — new limitedWriter type with shared byte counter and newLimitedWriterPair() constructor - limited_writer_test.go — 6 unit tests covering under/at/over limit, sticky failure, and shared counter behavior - run_predefined_script.go — replace bytes.Buffer with limitedWriter pair - run_predefined_powershell_script.go — same - run_shell_script_experimental.go — same (added limit enforcement where none existed) Co-authored-by: irving.santiago <irving.santiago@datadoghq.com>
1 parent f8d9965 commit d3c1220

5 files changed

Lines changed: 257 additions & 29 deletions

File tree

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Unless explicitly stated otherwise all files in this repository are licensed
2+
// under the Apache License Version 2.0.
3+
// This product includes software developed at Datadog (https://www.datadoghq.com/).
4+
// Copyright 2025-present Datadog, Inc.
5+
6+
package com_datadoghq_script
7+
8+
import (
9+
"bytes"
10+
"errors"
11+
"fmt"
12+
"sync/atomic"
13+
)
14+
15+
const defaultMaxOutputSize = 10 * 1024 * 1024 // 10MB
16+
17+
// errOutputLimitExceeded is a sentinel used for errors.Is matching.
18+
var errOutputLimitExceeded = errors.New("script output limit exceeded")
19+
20+
func newOutputLimitError(limit int64) error {
21+
return fmt.Errorf("script output exceeded %dMB limit: %w", limit/(1024*1024), errOutputLimitExceeded)
22+
}
23+
24+
// limitedWriter wraps a bytes.Buffer and enforces a shared byte limit across
25+
// one or more writers. Once the combined written bytes reach the limit,
26+
// subsequent writes return errOutputLimitExceeded, which causes the OS to
27+
// deliver a broken-pipe signal to the child process.
28+
type limitedWriter struct {
29+
buf bytes.Buffer
30+
shared *atomic.Int64 // shared counter across stdout+stderr writers
31+
limit int64
32+
limited bool // sticky flag: once true, all further writes fail
33+
}
34+
35+
// newLimitedStdoutStderrWritersPair creates two limitedWriters that share the same atomic
36+
// byte counter, so the combined output of stdout and stderr is bounded by limit.
37+
func newLimitedStdoutStderrWritersPair(limit int64) (*limitedWriter, *limitedWriter) {
38+
shared := &atomic.Int64{}
39+
return &limitedWriter{shared: shared, limit: limit},
40+
&limitedWriter{shared: shared, limit: limit}
41+
}
42+
43+
func (lw *limitedWriter) Write(p []byte) (int, error) {
44+
if lw.limited {
45+
return 0, errOutputLimitExceeded
46+
}
47+
48+
remaining := lw.limit - lw.shared.Load()
49+
if remaining <= 0 {
50+
lw.limited = true
51+
return 0, errOutputLimitExceeded
52+
}
53+
54+
toWrite := p
55+
if int64(len(p)) > remaining {
56+
toWrite = p[:remaining]
57+
lw.limited = true
58+
}
59+
60+
n, err := lw.buf.Write(toWrite)
61+
lw.shared.Add(int64(n))
62+
if err != nil {
63+
return n, err
64+
}
65+
66+
if lw.limited {
67+
return n, errOutputLimitExceeded
68+
}
69+
return n, nil
70+
}
71+
72+
func (lw *limitedWriter) String() string {
73+
return lw.buf.String()
74+
}
75+
76+
func (lw *limitedWriter) Len() int {
77+
return lw.buf.Len()
78+
}
79+
80+
// LimitReached returns true if the combined output limit was exceeded.
81+
func (lw *limitedWriter) LimitReached() bool {
82+
return lw.limited
83+
}
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Unless explicitly stated otherwise all files in this repository are licensed
2+
// under the Apache License Version 2.0.
3+
// This product includes software developed at Datadog (https://www.datadoghq.com/).
4+
// Copyright 2025-present Datadog, Inc.
5+
6+
package com_datadoghq_script
7+
8+
import (
9+
"sync"
10+
"testing"
11+
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
func TestLimitedWriter_UnderLimit(t *testing.T) {
17+
stdout, stderr := newLimitedStdoutStderrWritersPair(100)
18+
19+
n, err := stdout.Write([]byte("hello"))
20+
require.NoError(t, err)
21+
assert.Equal(t, 5, n)
22+
23+
n, err = stderr.Write([]byte("world"))
24+
require.NoError(t, err)
25+
assert.Equal(t, 5, n)
26+
27+
assert.Equal(t, "hello", stdout.String())
28+
assert.Equal(t, "world", stderr.String())
29+
assert.False(t, stdout.LimitReached())
30+
assert.False(t, stderr.LimitReached())
31+
}
32+
33+
func TestLimitedWriter_ExactLimit(t *testing.T) {
34+
stdout, stderr := newLimitedStdoutStderrWritersPair(10)
35+
36+
n, err := stdout.Write([]byte("12345"))
37+
require.NoError(t, err)
38+
assert.Equal(t, 5, n)
39+
40+
n, err = stderr.Write([]byte("67890"))
41+
require.NoError(t, err)
42+
assert.Equal(t, 5, n)
43+
44+
assert.Equal(t, "12345", stdout.String())
45+
assert.Equal(t, "67890", stderr.String())
46+
assert.False(t, stdout.LimitReached())
47+
assert.False(t, stderr.LimitReached())
48+
}
49+
50+
func TestLimitedWriter_ExceedsLimit(t *testing.T) {
51+
stdout, stderr := newLimitedStdoutStderrWritersPair(10)
52+
53+
n, err := stdout.Write([]byte("12345678"))
54+
require.NoError(t, err)
55+
assert.Equal(t, 8, n)
56+
57+
// This write exceeds the combined limit of 10; only 2 bytes should be written.
58+
n, err = stderr.Write([]byte("abcdef"))
59+
assert.ErrorIs(t, err, errOutputLimitExceeded)
60+
assert.Equal(t, 2, n)
61+
62+
assert.Equal(t, "12345678", stdout.String())
63+
assert.Equal(t, "ab", stderr.String())
64+
assert.True(t, stderr.LimitReached())
65+
}
66+
67+
func TestLimitedWriter_SubsequentWritesAfterLimit(t *testing.T) {
68+
stdout, stderr := newLimitedStdoutStderrWritersPair(5)
69+
70+
_, err := stdout.Write([]byte("12345"))
71+
require.NoError(t, err)
72+
73+
// Limit reached on next write
74+
n, err := stderr.Write([]byte("x"))
75+
assert.ErrorIs(t, err, errOutputLimitExceeded)
76+
assert.Equal(t, 0, n)
77+
assert.True(t, stderr.LimitReached())
78+
79+
// Further writes to stdout also fail
80+
n, err = stdout.Write([]byte("y"))
81+
assert.ErrorIs(t, err, errOutputLimitExceeded)
82+
assert.Equal(t, 0, n)
83+
assert.True(t, stdout.LimitReached())
84+
}
85+
86+
func TestLimitedWriter_SingleWriterExceedsLimit(t *testing.T) {
87+
stdout, stderr := newLimitedStdoutStderrWritersPair(5)
88+
89+
n, err := stdout.Write([]byte("0123456789"))
90+
assert.ErrorIs(t, err, errOutputLimitExceeded)
91+
assert.Equal(t, 5, n)
92+
assert.Equal(t, "01234", stdout.String())
93+
assert.True(t, stdout.LimitReached())
94+
95+
// stderr should also be blocked by the shared counter
96+
n, err = stderr.Write([]byte("a"))
97+
assert.ErrorIs(t, err, errOutputLimitExceeded)
98+
assert.Equal(t, 0, n)
99+
}
100+
101+
func TestLimitedWriter_SharedCounter(t *testing.T) {
102+
stdout, stderr := newLimitedStdoutStderrWritersPair(10)
103+
104+
// Alternate writes between stdout and stderr
105+
stdout.Write([]byte("aa")) // shared = 2
106+
stderr.Write([]byte("bbb")) // shared = 5
107+
stdout.Write([]byte("cc")) // shared = 7
108+
109+
// 3 bytes remaining in the shared budget
110+
n, err := stderr.Write([]byte("dddd"))
111+
assert.ErrorIs(t, err, errOutputLimitExceeded)
112+
assert.Equal(t, 3, n)
113+
114+
assert.Equal(t, "aacc", stdout.String())
115+
assert.Equal(t, "bbbddd", stderr.String())
116+
}
117+
118+
func TestLimitedWriter_ConcurrentWrites(t *testing.T) {
119+
const limit int64 = 1024
120+
stdout, stderr := newLimitedStdoutStderrWritersPair(limit)
121+
122+
chunk := []byte("abcdefghij") // 10 bytes per write
123+
var wg sync.WaitGroup
124+
125+
// Simulate exec.Cmd: two goroutines writing concurrently
126+
wg.Add(2)
127+
go func() {
128+
defer wg.Done()
129+
for {
130+
if _, err := stdout.Write(chunk); err != nil {
131+
return
132+
}
133+
}
134+
}()
135+
go func() {
136+
defer wg.Done()
137+
for {
138+
if _, err := stderr.Write(chunk); err != nil {
139+
return
140+
}
141+
}
142+
}()
143+
wg.Wait()
144+
145+
total := int64(stdout.Len() + stderr.Len())
146+
assert.True(t, stdout.LimitReached() || stderr.LimitReached())
147+
assert.LessOrEqual(t, total, limit+int64(len(chunk)),
148+
"total bytes (%d) should not exceed limit (%d) by more than one chunk (%d)", total, limit, len(chunk))
149+
}

pkg/privateactionrunner/bundles/script/run_predefined_powershell_script.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
package com_datadoghq_script
99

1010
import (
11-
"bytes"
1211
"context"
1312
"errors"
1413
"fmt"
@@ -97,30 +96,28 @@ func (h *RunPredefinedPowershellScriptHandler) Run(
9796
}
9897

9998
cmd := newPowershellCommand(ctx, evaluatedScript, script.AllowedEnvVars)
100-
var stdoutBuffer bytes.Buffer
101-
cmd.Stdout = &stdoutBuffer
102-
var stderrBuffer bytes.Buffer
103-
cmd.Stderr = &stderrBuffer
99+
stdoutWriter, stderrWriter := newLimitedStdoutStderrWritersPair(defaultMaxOutputSize)
100+
cmd.Stdout = stdoutWriter
101+
cmd.Stderr = stderrWriter
104102
start := time.Now()
105103
err = cmd.Run()
106104

107-
const maxOutputSize = 10 * 1024 * 1024 // 10MB
108-
if stdoutBuffer.Len()+stderrBuffer.Len() > maxOutputSize {
109-
return nil, errors.New("script output exceeded 10MB limit")
105+
if stdoutWriter.LimitReached() || stderrWriter.LimitReached() {
106+
return nil, newOutputLimitError(defaultMaxOutputSize)
110107
}
111108

112109
if err != nil && !inputs.NoFailOnError {
113110
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
114111
return nil, fmt.Errorf("script execution timed out after %d seconds", inputs.Timeout)
115112
}
116-
return nil, fmt.Errorf("failed to execute script: %w, stderr %s", err, stderrBuffer.String())
113+
return nil, fmt.Errorf("failed to execute script: %w, stderr %s", err, stderrWriter.String())
117114
}
118115

119116
return &RunPredefinedPowershellScriptOutputs{
120117
ExecutedCommand: cmd.String(),
121118
ExitCode: cmd.ProcessState.ExitCode(),
122-
Stdout: formatPowershellOutput(stdoutBuffer.String(), inputs.NoStripTrailingNewline),
123-
Stderr: formatPowershellOutput(stderrBuffer.String(), inputs.NoStripTrailingNewline),
119+
Stdout: formatPowershellOutput(stdoutWriter.String(), inputs.NoStripTrailingNewline),
120+
Stderr: formatPowershellOutput(stderrWriter.String(), inputs.NoStripTrailingNewline),
124121
DurationMillis: int(time.Since(start).Milliseconds()),
125122
}, nil
126123
}

pkg/privateactionrunner/bundles/script/run_predefined_script.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
package com_datadoghq_script
99

1010
import (
11-
"bytes"
1211
"context"
1312
"errors"
1413
"fmt"
@@ -79,30 +78,28 @@ func (h *RunPredefinedScriptHandler) Run(
7978
if err != nil {
8079
return nil, fmt.Errorf("invalid command arguments: %w", err)
8180
}
82-
var stdoutBuffer bytes.Buffer
83-
cmd.Stdout = &stdoutBuffer
84-
var stderrBuffer bytes.Buffer
85-
cmd.Stderr = &stderrBuffer
81+
stdoutWriter, stderrWriter := newLimitedStdoutStderrWritersPair(defaultMaxOutputSize)
82+
cmd.Stdout = stdoutWriter
83+
cmd.Stderr = stderrWriter
8684
start := time.Now()
8785
err = cmd.Run()
8886

89-
const maxOutputSize = 10 * 1024 * 1024 // 10MB
90-
if stdoutBuffer.Len()+stderrBuffer.Len() > maxOutputSize {
91-
return nil, errors.New("script output exceeded 10MB limit")
87+
if stdoutWriter.LimitReached() || stderrWriter.LimitReached() {
88+
return nil, newOutputLimitError(defaultMaxOutputSize)
9289
}
9390

9491
if err != nil && !inputs.NoFailOnError {
9592
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
9693
return nil, fmt.Errorf("script execution timed out after %d seconds", inputs.Timeout)
9794
}
98-
return nil, fmt.Errorf("failed to execute command: %w, stderr %s", err, stderrBuffer.String())
95+
return nil, fmt.Errorf("failed to execute command: %w, stderr %s", err, stderrWriter.String())
9996
}
10097

10198
return &RunPredefinedScriptOutputs{
10299
ExecutedCommand: cmd.String(),
103100
ExitCode: cmd.ProcessState.ExitCode(),
104-
Stdout: formatOutput(stdoutBuffer.String(), inputs.NoStripTrailingNewline),
105-
Stderr: formatOutput(stderrBuffer.String(), inputs.NoStripTrailingNewline),
101+
Stdout: formatOutput(stdoutWriter.String(), inputs.NoStripTrailingNewline),
102+
Stderr: formatOutput(stderrWriter.String(), inputs.NoStripTrailingNewline),
106103
DurationMillis: int(time.Since(start).Milliseconds()),
107104
}, nil
108105
}

pkg/privateactionrunner/bundles/script/run_shell_script_experimental.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
package com_datadoghq_script
99

1010
import (
11-
"bytes"
1211
"context"
1312
"errors"
1413
"fmt"
@@ -85,14 +84,17 @@ func (h *RunShellScriptHandler) Run(
8584
if err != nil {
8685
return nil, fmt.Errorf("invalid command arguments: %w", err)
8786
}
88-
var stdoutBuffer bytes.Buffer
89-
cmd.Stdout = &stdoutBuffer
90-
var stderrBuffer bytes.Buffer
91-
cmd.Stderr = &stderrBuffer
87+
stdoutWriter, stderrWriter := newLimitedStdoutStderrWritersPair(maxOutputSize)
88+
cmd.Stdout = stdoutWriter
89+
cmd.Stderr = stderrWriter
9290
start := time.Now()
9391
err = cmd.Run()
9492

95-
stdErr := sanitizeErrorMessage(scriptFile.Name(), stderrBuffer.String())
93+
if stdoutWriter.LimitReached() || stderrWriter.LimitReached() {
94+
return nil, newOutputLimitError(defaultMaxOutputSize)
95+
}
96+
97+
stdErr := sanitizeErrorMessage(scriptFile.Name(), stderrWriter.String())
9698

9799
if err != nil {
98100
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
@@ -106,7 +108,7 @@ func (h *RunShellScriptHandler) Run(
106108
return &RunShellScriptOutputs{
107109
ExecutedCommand: cmd.String(),
108110
ExitCode: cmd.ProcessState.ExitCode(),
109-
Stdout: formatOutput(stdoutBuffer.String(), inputs.NoStripTrailingNewline),
111+
Stdout: formatOutput(stdoutWriter.String(), inputs.NoStripTrailingNewline),
110112
Stderr: formatOutput(stdErr, inputs.NoStripTrailingNewline),
111113
DurationMillis: int(time.Since(start).Milliseconds()),
112114
}, nil

0 commit comments

Comments
 (0)