Skip to content

Commit abca2c6

Browse files
committed
fix(direct_proxy): address review feedback for security and consistency
- call_tool: add feature-flag gate (settings.mcpgateway_direct_proxy_enabled) matching list_tools/list_resources/read_resource - call_tool: return error on proxy failure instead of falling through to cache mode (fail-closed) - call_tool: sanitize error message to avoid leaking exception details - read_resource: return empty string directly on access denial instead of raising HTTPException that gets swallowed by outer handler - read_resource: remove _meta from session.read_resource() call (MCP SDK only accepts uri parameter) - GatewayCreate: change gateway_mode from Optional[str] to str with before-validator defaulting None to 'cache' (prevents DB integrity error) - admin: expose direct_proxy_timeout in Connection Timeouts section - list_tools docstring: fix refresh_strategy -> gateway_mode - Fix jq test assertions to match TextContent return type Signed-off-by: Mihai Criveti <crivetimihai@gmail.com>
1 parent af61f72 commit abca2c6

7 files changed

Lines changed: 90 additions & 52 deletions

File tree

mcpgateway/admin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,6 +1406,7 @@ def mask_sensitive(value: Any, key: str) -> Any:
14061406
},
14071407
"Connection Timeouts": {
14081408
"federation_timeout": settings.federation_timeout, # Gateway/server HTTP request timeout
1409+
"mcpgateway_direct_proxy_timeout": settings.mcpgateway_direct_proxy_timeout,
14091410
},
14101411
"Transport": {
14111412
"transport_type": settings.transport_type,

mcpgateway/schemas.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2549,9 +2549,20 @@ class GatewayCreate(BaseModel):
25492549
refresh_interval_seconds: Optional[int] = Field(None, ge=60, description="Per-gateway refresh interval in seconds (minimum 60); uses global default if not set")
25502550

25512551
# Gateway mode configuration
2552-
gateway_mode: Optional[str] = Field(
2553-
default="cache", description="Gateway mode: 'cache' (database caching, default) or 'direct_proxy' (pass-through mode with no caching)", pattern="^(cache|direct_proxy)$"
2554-
)
2552+
gateway_mode: str = Field(default="cache", description="Gateway mode: 'cache' (database caching, default) or 'direct_proxy' (pass-through mode with no caching)", pattern="^(cache|direct_proxy)$")
2553+
2554+
@field_validator("gateway_mode", mode="before")
2555+
@classmethod
2556+
def default_gateway_mode(cls, v: Optional[str]) -> str:
2557+
"""Default gateway_mode to 'cache' when None is provided.
2558+
2559+
Args:
2560+
v: Gateway mode value (may be None).
2561+
2562+
Returns:
2563+
The validated gateway mode string, defaulting to 'cache'.
2564+
"""
2565+
return v if v is not None else "cache"
25552566

25562567
@field_validator("tags")
25572568
@classmethod

mcpgateway/services/resource_service.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,13 +2201,8 @@ async def read_resource(
22012201
# Skip session initialize for stateless servers
22022202
# await session.initialize()
22032203

2204-
# Read resource with _meta if provided
2205-
read_params = {"uri": uri}
2206-
if meta_data:
2207-
read_params["_meta"] = meta_data
2208-
logger.debug(f"Forwarding _meta to remote gateway: {meta_data}")
2209-
2210-
result = await session.read_resource(**read_params)
2204+
# Note: MCP SDK read_resource() only accepts uri; _meta is not supported
2205+
result = await session.read_resource(uri=uri)
22112206

22122207
# Convert MCP result to MCP-compliant content models
22132208
# result.contents is a list of TextResourceContents or BlobResourceContents

mcpgateway/transports/streamablehttp_transport.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,7 @@ async def call_tool(name: str, arguments: dict) -> Union[
701701
from mcpgateway.db import Gateway as DbGateway # pylint: disable=import-outside-toplevel
702702

703703
gateway = check_db.execute(select(DbGateway).where(DbGateway.id == gateway_id_from_header)).scalar_one_or_none()
704-
if gateway and getattr(gateway, "gateway_mode", "cache") == "direct_proxy":
704+
if gateway and getattr(gateway, "gateway_mode", "cache") == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
705705
# SECURITY: Check gateway access before allowing direct proxy
706706
if not await check_gateway_access(check_db, gateway, user_email, token_teams):
707707
logger.warning(f"Access denied to gateway {gateway_id_from_header} in direct_proxy mode for user {user_email}")
@@ -722,7 +722,7 @@ async def call_tool(name: str, arguments: dict) -> Union[
722722
)
723723
except Exception as e:
724724
logger.error(f"Direct proxy mode failed for gateway {gateway_id_from_header}: {e}")
725-
# Fall through to normal mode on error
725+
return types.CallToolResult(content=[types.TextContent(type="text", text="Direct proxy tool invocation failed")], isError=True)
726726

727727
# Normal mode: use standard tool invocation with normalization
728728
app_user_email = get_user_email_from_context() # Keep for OAuth token selection
@@ -952,7 +952,7 @@ async def list_tools() -> List[types.Tool]:
952952
"""
953953
Lists all tools available to the MCP Server.
954954
955-
Supports two modes based on gateway's refresh_strategy:
955+
Supports two modes based on gateway's gateway_mode:
956956
- 'cache': Returns tools from database (default behavior)
957957
- 'direct_proxy': Proxies the request directly to the remote MCP server
958958
@@ -1278,9 +1278,6 @@ async def read_resource(resource_uri: str) -> Union[str, bytes]:
12781278
Union[str, bytes]: The content of the resource as text or binary data.
12791279
Returns empty string on failure or if no content is found.
12801280
1281-
Raises:
1282-
HTTPException: If access is denied to the gateway in direct_proxy mode.
1283-
12841281
Logs exceptions if any errors occur during reading.
12851282
12861283
Examples:
@@ -1326,7 +1323,6 @@ async def read_resource(resource_uri: str) -> Union[str, bytes]:
13261323
# If X-Context-Forge-Gateway-Id is provided, check if that gateway is in direct_proxy mode
13271324
if gateway_id:
13281325
# Third-Party
1329-
from fastapi import HTTPException # pylint: disable=import-outside-toplevel
13301326
from sqlalchemy import select # pylint: disable=import-outside-toplevel
13311327

13321328
# First-Party
@@ -1337,7 +1333,7 @@ async def read_resource(resource_uri: str) -> Union[str, bytes]:
13371333
# SECURITY: Check gateway access before allowing direct proxy
13381334
if not await check_gateway_access(db, gateway, user_email, token_teams):
13391335
logger.warning(f"Access denied to gateway {gateway_id} in direct_proxy mode for user {user_email}")
1340-
raise HTTPException(status_code=404, detail="Resource not found")
1336+
return ""
13411337

13421338
# Direct proxy mode: forward request to remote MCP server
13431339
# Get _meta from request context if available

tests/unit/mcpgateway/services/test_resource_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5822,7 +5822,7 @@ async def mock_streamable_client_error(*_args, **_kwargs):
58225822

58235823
@pytest.mark.asyncio
58245824
async def test_read_resource_direct_proxy_with_meta(self, resource_service, mock_direct_proxy_resource):
5825-
"""meta_data is forwarded to session.read_resource as _meta."""
5825+
"""meta_data is accepted but not forwarded to session.read_resource (SDK doesn't support _meta)."""
58265826
from contextlib import asynccontextmanager
58275827

58285828
db = self._make_mock_db(mock_direct_proxy_resource)
@@ -5860,9 +5860,9 @@ async def mock_streamable_client(*_args, **_kwargs):
58605860
meta_data=meta,
58615861
)
58625862

5863+
# MCP SDK read_resource() only accepts uri; _meta is not forwarded
58635864
session_mock.read_resource.assert_awaited_once_with(
58645865
uri="http://example.com/dp-resource",
5865-
_meta={"request_id": "trace-abc-123"},
58665866
)
58675867

58685868
@pytest.mark.asyncio

tests/unit/mcpgateway/services/test_tool_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,14 +199,14 @@ def all(self):
199199
monkeypatch.setattr("mcpgateway.services.tool_service._compile_jq_filter", lambda _f: DummyProgram())
200200

201201
result = extract_using_jq({"a": 1}, ".a")
202-
assert result == "Error applying jsonpath filter"
202+
assert result == [TextContent(type="text", text="Error applying jsonpath filter")]
203203

204204
def test_extract_using_jq_returns_exception_message(self, monkeypatch):
205-
"""Exceptions during jq execution should return message string."""
205+
"""Exceptions during jq execution should return list with TextContent error."""
206206
monkeypatch.setattr("mcpgateway.services.tool_service._compile_jq_filter", lambda _f: (_ for _ in ()).throw(RuntimeError("boom")))
207207

208208
result = extract_using_jq({"a": 1}, ".a")
209-
assert result == "Error applying jsonpath filter: boom"
209+
assert result == [TextContent(type="text", text="Error applying jsonpath filter: boom")]
210210

211211
def test_tool_service_plugin_env_override(self, monkeypatch):
212212
"""PLUGINS_ENABLED env flag should override settings."""

tests/unit/mcpgateway/transports/test_streamablehttp_transport.py

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6797,10 +6797,8 @@ class MockContent:
67976797
assert result == b"Binary data"
67986798

67996799
@pytest.mark.asyncio
6800-
async def test_read_resource_direct_proxy_access_denied_raises_404(self):
6801-
"""Test read_resource raises HTTPException 404 when gateway access is denied."""
6802-
from fastapi import HTTPException
6803-
6800+
async def test_read_resource_direct_proxy_access_denied_returns_empty(self):
6801+
"""Test read_resource returns empty string when gateway access is denied."""
68046802
mock_gateway = MagicMock()
68056803
mock_gateway.id = "gw-direct"
68066804
mock_gateway.gateway_mode = "direct_proxy"
@@ -6818,11 +6816,9 @@ async def mock_get_db():
68186816

68196817
with patch("mcpgateway.transports.streamablehttp_transport.get_db", mock_get_db):
68206818
with patch("mcpgateway.transports.streamablehttp_transport.check_gateway_access", return_value=False):
6821-
# The HTTPException is raised but caught by outer try/except and logged
6822-
# So we expect empty string return, not an exception
68236819
result = await tr.read_resource("file:///test.txt")
68246820

6825-
# Access denied returns empty string (exception is caught and logged)
6821+
# Access denied returns empty string directly (no exception raised)
68266822
assert result == ""
68276823

68286824
@pytest.mark.asyncio
@@ -6859,6 +6855,12 @@ async def mock_get_db():
68596855
class TestCallToolDirectProxy:
68606856
"""Test direct_proxy mode in the call_tool handler."""
68616857

6858+
@pytest.fixture(autouse=True)
6859+
def enable_direct_proxy(self):
6860+
"""Enable direct_proxy feature flag for all tests in this class."""
6861+
with patch.object(tr.settings, "mcpgateway_direct_proxy_enabled", True):
6862+
yield
6863+
68626864
@pytest.mark.asyncio
68636865
async def test_call_tool_direct_proxy_success(self):
68646866
"""Test call_tool returns CallToolResult from invoke_tool_direct when
@@ -6938,8 +6940,10 @@ async def mock_get_db():
69386940
assert result.content[0].text == "Tool not found: secret_tool"
69396941

69406942
@pytest.mark.asyncio
6941-
async def test_call_tool_direct_proxy_exception_falls_through(self):
6942-
"""Test call_tool falls through to normal mode when invoke_tool_direct raises."""
6943+
async def test_call_tool_direct_proxy_exception_returns_error(self):
6944+
"""Test call_tool returns error when invoke_tool_direct raises (no fallback to cache mode)."""
6945+
from mcp import types as mcp_types
6946+
69436947
mock_gateway = MagicMock()
69446948
mock_gateway.id = "gw-direct"
69456949
mock_gateway.gateway_mode = "direct_proxy"
@@ -6956,20 +6960,6 @@ async def mock_get_db():
69566960
# invoke_tool_direct raises an exception
69576961
mock_invoke_direct = AsyncMock(side_effect=RuntimeError("connection failed"))
69586962

6959-
# Normal mode invoke_tool returns a result with content
6960-
# Attributes must be explicit (not auto-created by MagicMock) so the
6961-
# content normalization code in call_tool works correctly.
6962-
mock_content_item = MagicMock(spec=[])
6963-
mock_content_item.type = "text"
6964-
mock_content_item.text = "normal result"
6965-
mock_content_item.annotations = None
6966-
mock_content_item.meta = None
6967-
mock_content_item.size = None
6968-
normal_result = MagicMock(spec=[])
6969-
normal_result.content = [mock_content_item]
6970-
normal_result.structuredContent = None
6971-
mock_invoke_normal = AsyncMock(return_value=normal_result)
6972-
69736963
tr.server_id_var.set("server-123")
69746964
tr.request_headers_var.set({"x-context-forge-gateway-id": "gw-direct"})
69756965
tr.user_context_var.set({"email": "user@test.com", "teams": ["team1"], "is_admin": False})
@@ -6978,15 +6968,14 @@ async def mock_get_db():
69786968
with patch("mcpgateway.transports.streamablehttp_transport.extract_gateway_id_from_headers", return_value="gw-direct"):
69796969
with patch("mcpgateway.transports.streamablehttp_transport.check_gateway_access", new_callable=AsyncMock, return_value=True):
69806970
with patch.object(tr.tool_service, "invoke_tool_direct", mock_invoke_direct):
6981-
with patch.object(tr.tool_service, "invoke_tool", mock_invoke_normal):
6982-
with patch("mcpgateway.transports.streamablehttp_transport.settings") as mock_settings:
6983-
mock_settings.mcpgateway_session_affinity_enabled = False
6984-
result = await tr.call_tool("my_tool", {"arg": "value"})
6971+
result = await tr.call_tool("my_tool", {"arg": "value"})
69856972

69866973
# invoke_tool_direct was called and raised
69876974
mock_invoke_direct.assert_awaited_once()
6988-
# Normal mode invoke_tool was called as fallback
6989-
mock_invoke_normal.assert_awaited_once()
6975+
# Should return error result, NOT fall through to normal mode
6976+
assert isinstance(result, mcp_types.CallToolResult)
6977+
assert result.isError is True
6978+
assert result.content[0].text == "Direct proxy tool invocation failed"
69906979

69916980
@pytest.mark.asyncio
69926981
async def test_call_tool_direct_proxy_gateway_not_direct_proxy_falls_through(self):
@@ -7029,3 +7018,49 @@ async def mock_get_db():
70297018

70307019
# Normal mode invoke_tool was called since gateway is not direct_proxy
70317020
mock_invoke_normal.assert_awaited_once()
7021+
7022+
@pytest.mark.asyncio
7023+
async def test_call_tool_direct_proxy_feature_disabled_falls_through(self):
7024+
"""Test call_tool falls through to normal mode when feature flag is disabled."""
7025+
mock_gateway = MagicMock()
7026+
mock_gateway.id = "gw-direct"
7027+
mock_gateway.gateway_mode = "direct_proxy"
7028+
7029+
mock_db = MagicMock()
7030+
mock_db.execute = MagicMock(
7031+
return_value=MagicMock(scalar_one_or_none=MagicMock(return_value=mock_gateway))
7032+
)
7033+
7034+
@asynccontextmanager
7035+
async def mock_get_db():
7036+
yield mock_db
7037+
7038+
mock_content_item = MagicMock(spec=[])
7039+
mock_content_item.type = "text"
7040+
mock_content_item.text = "cache result"
7041+
mock_content_item.annotations = None
7042+
mock_content_item.meta = None
7043+
mock_content_item.size = None
7044+
normal_result = MagicMock(spec=[])
7045+
normal_result.content = [mock_content_item]
7046+
normal_result.structuredContent = None
7047+
mock_invoke_normal = AsyncMock(return_value=normal_result)
7048+
mock_invoke_direct = AsyncMock()
7049+
7050+
tr.server_id_var.set("server-123")
7051+
tr.request_headers_var.set({"x-context-forge-gateway-id": "gw-direct"})
7052+
tr.user_context_var.set({"email": "user@test.com", "teams": ["team1"], "is_admin": False})
7053+
7054+
with patch("mcpgateway.transports.streamablehttp_transport.get_db", mock_get_db):
7055+
with patch("mcpgateway.transports.streamablehttp_transport.extract_gateway_id_from_headers", return_value="gw-direct"):
7056+
with patch.object(tr.tool_service, "invoke_tool_direct", mock_invoke_direct):
7057+
with patch.object(tr.tool_service, "invoke_tool", mock_invoke_normal):
7058+
with patch("mcpgateway.transports.streamablehttp_transport.settings") as mock_settings:
7059+
mock_settings.mcpgateway_direct_proxy_enabled = False
7060+
mock_settings.mcpgateway_session_affinity_enabled = False
7061+
result = await tr.call_tool("my_tool", {"arg": "value"})
7062+
7063+
# Direct proxy was NOT called since feature flag is disabled
7064+
mock_invoke_direct.assert_not_awaited()
7065+
# Normal mode was used instead
7066+
mock_invoke_normal.assert_awaited_once()

0 commit comments

Comments
 (0)