diff --git a/middleware/adaptor/adaptor.go b/middleware/adaptor/adaptor.go index ea65dfc5..c266142a 100644 --- a/middleware/adaptor/adaptor.go +++ b/middleware/adaptor/adaptor.go @@ -54,13 +54,24 @@ func ConvertRequest(c fiber.Ctx, forServer bool) (*http.Request, error) { // CopyContextToFiberContext copies the values of context.Context to a fasthttp.RequestCtx. func CopyContextToFiberContext(context any, requestContext *fasthttp.RequestCtx) { - contextValues := reflect.ValueOf(context).Elem() - contextKeys := reflect.TypeOf(context).Elem() + contextValue := reflect.ValueOf(context) + contextType := reflect.TypeOf(context) - if contextKeys.Kind() != reflect.Struct { + if contextValue.Kind() == reflect.Ptr { + if contextValue.IsNil() { + return + } + contextValue = contextValue.Elem() + contextType = contextType.Elem() + } + + if contextType.Kind() != reflect.Struct { return } + contextValues := contextValue + contextKeys := contextType + var lastKey any for i := 0; i < contextValues.NumField(); i++ { reflectValue := contextValues.Field(i) diff --git a/middleware/adaptor/adaptor_test.go b/middleware/adaptor/adaptor_test.go index 809939e2..f5aa1a85 100644 --- a/middleware/adaptor/adaptor_test.go +++ b/middleware/adaptor/adaptor_test.go @@ -4,6 +4,7 @@ package adaptor import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -433,6 +434,30 @@ func (r *netHTTPBody) Close() error { return nil } +func createTestRequest(method, uri, remoteAddr string, body io.Reader) *http.Request { + r := &http.Request{ + Method: method, + RequestURI: uri, + RemoteAddr: remoteAddr, + Header: make(http.Header), + Body: http.NoBody, + } + if body != nil { + if rc, ok := body.(io.ReadCloser); ok { + r.Body = rc + } else { + r.Body = io.NopCloser(body) + } + } + return r +} + +func executeHandlerTest(_ *testing.T, handler http.HandlerFunc, req *http.Request) *netHTTPResponseWriter { + w := &netHTTPResponseWriter{} + handler.ServeHTTP(w, req) + return w +} + type netHTTPResponseWriter struct { h http.Header body []byte @@ -465,24 +490,167 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { func Test_ConvertRequest(t *testing.T) { t.Parallel() - app := fiber.New() + t.Run("successful conversion", func(t *testing.T) { + t.Parallel() + app := fiber.New() - app.Get("/test", func(c fiber.Ctx) error { - httpReq, err := ConvertRequest(c, false) + app.Get("/test", func(c fiber.Ctx) error { + httpReq, err := ConvertRequest(c, false) + if err != nil { + return err + } + return c.SendString("Request URL: " + httpReq.URL.String()) + }) + + resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", nil)) + require.NoError(t, err, "app.Test(req)") + require.Equal(t, http.StatusOK, resp.StatusCode, "Status code") + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "Request URL: /test?hello=world&another=test", string(body)) + }) + + t.Run("conversion error handling", func(t *testing.T) { + t.Parallel() + // Test error case by creating a context with an invalid URL that will cause fasthttpadaptor.ConvertRequest to fail + app := fiber.New() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // Create a malformed request URI that should cause conversion to fail + ctx.Request().SetRequestURI("http://[::1:bad:url") // Invalid URL format + ctx.Request().Header.SetMethod(fiber.MethodGet) + + _, err := ConvertRequest(ctx, true) // Use forServer=true which does more validation + if err == nil { + // If the above doesn't fail, try a different approach + ctx.Request().SetRequestURI("\x00\x01\x02") // Invalid characters in URI + _, err = ConvertRequest(ctx, true) + } + // Note: This test may pass if fasthttpadaptor is very permissive + // The important thing is that our function doesn't panic if err != nil { - return err + require.Error(t, err, "Expected error from fasthttpadaptor.ConvertRequest") } + }) +} + +func Test_CopyContextToFiberContext(t *testing.T) { + t.Parallel() - return c.SendString("Request URL: " + httpReq.URL.String()) + t.Run("unsupported context type", func(t *testing.T) { + t.Parallel() + // Test with non-struct context (should return early) + var fctx fasthttp.RequestCtx + stringContext := "not a struct" + + // This should not panic and should handle the non-struct gracefully + CopyContextToFiberContext(&stringContext, &fctx) + // No assertions needed - just ensuring it doesn't panic }) - resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, http.StatusOK, resp.StatusCode, "Status code") + t.Run("context with unknown field", func(t *testing.T) { + t.Parallel() + // Test the default case (continue statement coverage) + type customContext struct { + UnknownField string + } + + var fctx fasthttp.RequestCtx + ctx := customContext{UnknownField: "test"} - body, err := io.ReadAll(resp.Body) + // This should hit the default case and continue + CopyContextToFiberContext(&ctx, &fctx) + // No assertions needed - just ensuring it doesn't panic and continues + }) +} + +func Test_HTTPMiddleware_ErrorHandling(t *testing.T) { + t.Parallel() + + // Test middleware that returns an error from HTTPHandler + errorMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // This will cause an error in the underlying handler + w.WriteHeader(http.StatusInternalServerError) + next.ServeHTTP(w, r) + }) + } + + fiberHandler := func(c fiber.Ctx) error { + return fiber.NewError(fiber.StatusBadRequest, "test error") + } + + app := fiber.New() + app.Use(HTTPMiddleware(errorMiddleware)) + app.Get("/error", fiberHandler) + + resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/error", nil)) require.NoError(t, err) - require.Equal(t, "Request URL: /test?hello=world&another=test", string(body)) + // The error should be handled by the error handler + require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) +} + +func Test_FiberHandler_IOError(t *testing.T) { + t.Parallel() + + // Test io.Copy error by using a failing reader + fiberH := func(c fiber.Ctx) error { + return c.SendString("should not reach here") + } + handlerFunc := FiberHandlerFunc(fiberH) + + // Create a reader that fails + failingReader := &failingReader{} + + r := &http.Request{ + Method: http.MethodPost, + RequestURI: "/test", + Body: failingReader, + ContentLength: 100, // Set content length so it tries to read + Header: make(http.Header), + } + + w := &netHTTPResponseWriter{} + handlerFunc.ServeHTTP(w, r) + + // Should return 500 due to io.Copy error + require.Equal(t, http.StatusInternalServerError, w.StatusCode()) +} + +func Test_FiberHandler_WithErrorInHandler(t *testing.T) { + t.Parallel() + + // Test error handling in fiber handler + fiberH := func(c fiber.Ctx) error { + return fiber.NewError(fiber.StatusTeapot, "I'm a teapot") + } + handlerFunc := FiberHandlerFunc(fiberH) + + r := &http.Request{ + Method: http.MethodGet, + RequestURI: "/test", + Header: make(http.Header), + Body: http.NoBody, + } + + w := &netHTTPResponseWriter{} + handlerFunc.ServeHTTP(w, r) + + // Should return the error status + require.Equal(t, fiber.StatusTeapot, w.StatusCode()) +} + +// failingReader always returns an error when Read is called +type failingReader struct{} + +func (f *failingReader) Read(p []byte) (int, error) { + return 0, errors.New("simulated read error") +} + +func (f *failingReader) Close() error { + return nil } // Benchmark for FiberHandlerFunc @@ -667,6 +835,143 @@ func Benchmark_HTTPHandler(b *testing.B) { require.NoError(b, err) } +func Test_resolveRemoteAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + localAddr any + name string + remoteAddr string + errorContains string + expectError bool + }{ + { + name: "valid TCP address with port", + remoteAddr: "192.168.1.1:8080", + localAddr: nil, + expectError: false, + }, + { + name: "valid TCP address without port - should add default port 80", + remoteAddr: "192.168.1.1", + localAddr: nil, + expectError: false, + }, + { + name: "unix socket - should return local addr", + remoteAddr: "irrelevant", + localAddr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"}, + expectError: false, + }, + { + name: "invalid address - should fail", + remoteAddr: "[invalid:address:format", + localAddr: nil, + expectError: true, + errorContains: "failed to resolve TCP address:", + }, + { + name: "invalid address after adding port - should fail", + remoteAddr: "[invalid", + localAddr: nil, + expectError: true, + errorContains: "failed to resolve TCP address after adding port:", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + addr, err := resolveRemoteAddr(tt.remoteAddr, tt.localAddr) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + require.Contains(t, err.Error(), tt.errorContains) + } + require.Nil(t, addr) + } else { + require.NoError(t, err) + require.NotNil(t, addr) + } + }) + } +} + +func Test_isUnixNetwork(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + network string + expected bool + }{ + {"unix", "unix", true}, + {"unixgram", "unixgram", true}, + {"unixpacket", "unixpacket", true}, + {"tcp", "tcp", false}, + {"tcp4", "tcp4", false}, + {"tcp6", "tcp6", false}, + {"udp", "udp", false}, + {"empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := isUnixNetwork(tt.network) + require.Equal(t, tt.expected, result) + }) + } +} + +func Test_FiberHandler_ErrorFallback(t *testing.T) { + t.Parallel() + + // Test case where resolveRemoteAddr fails and falls back to nil + fiberH := func(c fiber.Ctx) error { + return c.SendString("success") + } + handlerFunc := FiberHandlerFunc(fiberH) + + // Use helper function for cleaner test setup + req := createTestRequest(http.MethodGet, "/test", "[invalid:address:format", nil) + w := executeHandlerTest(t, handlerFunc, req) + + // Should still work despite the invalid remote address + require.Equal(t, http.StatusOK, w.StatusCode()) + require.Equal(t, "success", string(w.body)) +} + +func Test_FiberHandler_WithUnixSocket(t *testing.T) { + t.Parallel() + + // Test case where request has unix socket context + fiberH := func(c fiber.Ctx) error { + return c.SendString("unix socket success") + } + handlerFunc := FiberHandlerFunc(fiberH) + + // Create a context with unix socket local address + unixAddr := &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"} + ctx := context.WithValue(context.Background(), http.LocalAddrContextKey, unixAddr) + + r := &http.Request{ + Method: http.MethodGet, + RequestURI: "/test", + RemoteAddr: "someremoteaddr", // This will be ignored due to unix socket + Header: make(http.Header), + Body: http.NoBody, + } + r = r.WithContext(ctx) + + w := &netHTTPResponseWriter{} + handlerFunc.ServeHTTP(w, r) + + require.Equal(t, http.StatusOK, w.StatusCode()) + require.Equal(t, "unix socket success", string(w.body)) +} + func TestUnixSocketAdaptor(t *testing.T) { dir := t.TempDir() socketPath := filepath.Join(dir, "test.sock")