Skip to content

Commit c7f0912

Browse files
authored
feat: add panic handling for all the plugin handlers (#287)
Signed-off-by: Gergely Brautigam <182850+Skarlso@users.noreply.github.com>
1 parent 60bc4a4 commit c7f0912

2 files changed

Lines changed: 56 additions & 1 deletion

File tree

bindings/go/plugin/client/sdk/plugin.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func (p *Plugin) listen(ctx context.Context) error {
150150

151151
m := http.NewServeMux()
152152
for _, h := range p.handlers {
153-
m.HandleFunc(h.Location, h.Handler)
153+
m.HandleFunc(h.Location, p.panicRecovery(h.Handler))
154154
}
155155

156156
m.HandleFunc("/shutdown", p.Shutdown)
@@ -186,6 +186,22 @@ func (p *Plugin) listen(ctx context.Context) error {
186186
return server.Serve(conn)
187187
}
188188

189+
func (p *Plugin) panicRecovery(f func(w http.ResponseWriter, r *http.Request)) http.HandlerFunc {
190+
return func(w http.ResponseWriter, r *http.Request) {
191+
defer func() {
192+
if err := recover(); err != nil {
193+
p.logger.ErrorContext(r.Context(), "panic recovered", "error", err)
194+
plugins.NewError(
195+
errors.New("panic recovered"),
196+
http.StatusInternalServerError).
197+
Write(w)
198+
}
199+
}()
200+
201+
f(w, r)
202+
}
203+
}
204+
189205
func (p *Plugin) determineLocation() (_ string, err error) {
190206
switch p.Config.Type {
191207
case types.Socket:

bindings/go/plugin/client/sdk/plugin_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,45 @@ func TestHealthCheckInvalidMethod(t *testing.T) {
201201
r.Contains(string(content), "this endpoint may only be called with either HEAD or GET method")
202202
}
203203

204+
func TestPanicRecovery(t *testing.T) {
205+
r := require.New(t)
206+
location := "/tmp/test-plugin-panic-plugin.socket"
207+
output := bytes.NewBuffer(nil)
208+
ctx := context.Background()
209+
p := NewPlugin(ctx, slog.Default(), types.Config{
210+
ID: "test-plugin-panic",
211+
Type: types.Socket,
212+
PluginType: types.ComponentVersionRepositoryPluginType,
213+
}, output)
214+
215+
t.Cleanup(func() {
216+
r.NoError(os.RemoveAll(location))
217+
})
218+
219+
r.NoError(p.RegisterHandlers(endpoints.Handler{
220+
Handler: func(writer http.ResponseWriter, request *http.Request) {
221+
panic("test panic")
222+
},
223+
Location: "/panic-endpoint",
224+
}))
225+
226+
go func() {
227+
_ = p.Start(ctx)
228+
}()
229+
230+
httpClient := createHttpClient(location)
231+
waitForPlugin(r, httpClient)
232+
233+
resp, err := httpClient.Get("http://unix/panic-endpoint")
234+
r.NoError(err)
235+
r.Equal(http.StatusInternalServerError, resp.StatusCode)
236+
content, err := io.ReadAll(resp.Body)
237+
r.NoError(err)
238+
r.Contains(string(content), "panic recovered")
239+
240+
r.NoError(p.GracefulShutdown(ctx))
241+
}
242+
204243
func waitForPlugin(r *require.Assertions, httpClient *http.Client) {
205244
r.Eventually(func() bool {
206245
resp, err := httpClient.Get("http://unix/healthz")

0 commit comments

Comments
 (0)