Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions fastapi/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
"pip install python-multipart\n"
)

# Sentinel value for unspecified fields
_NOT_SPECIFIED = object()


def ensure_multipart_is_installed() -> None:
try:
Expand Down Expand Up @@ -690,7 +693,7 @@ async def solve_dependencies(
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
if value is None:
if value is _NOT_SPECIFIED:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
Expand All @@ -712,20 +715,24 @@ def _get_multidict_value(
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias)
else:
value = values.get(alias, None)
value = values.get(alias, _NOT_SPECIFIED)
if (
value is None
or (
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
and value == ""
)
or (is_sequence_field(field) and len(value) == 0)
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
and value == ""
Comment on lines +720 to +722
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify why you removed the or (is_sequence_field(field) and len(value) == 0) part from this condition?

):
# Empty strings in a form can be a representation of None values
_, error = field.validate(None, {}, loc=())
# If None is an accepted value for this field, use that
if error is None:
value = None
Comment on lines +724 to +728
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't really like this part.. Can we try getting this information from field annotation?


if value == _NOT_SPECIFIED or (is_sequence_field(field) and len(value) == 0):
if field.required:
return
return _NOT_SPECIFIED
else:
return deepcopy(field.default)

return value


Expand Down Expand Up @@ -763,7 +770,7 @@ def request_params_to_args(
else field.name.replace("_", "-")
)
value = _get_multidict_value(field, received_params, alias=alias)
if value is not None:
if value != _NOT_SPECIFIED:
params_to_process[field.name] = value
processed_keys.add(alias or field.alias)
processed_keys.add(field.name)
Expand Down Expand Up @@ -830,6 +837,8 @@ async def _extract_form_body(
first_field = body_fields[0]
first_field_info = first_field.field_info

processed_keys = set()

for field in body_fields:
value = _get_multidict_value(field, received_body)
if (
Expand Down Expand Up @@ -857,10 +866,13 @@ async def process_fn(
for sub_value in value:
tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results)
if value is not None:
if value != _NOT_SPECIFIED:
values[field.alias] = value
processed_keys.add(field.alias)
processed_keys.add(field.name)

for key, value in received_body.items():
if key not in values:
if key not in processed_keys:
values[key] = value
return values

Expand Down Expand Up @@ -888,15 +900,18 @@ async def request_body_to_args(
if single_not_embedded_field:
loc: Tuple[str, ...] = ("body",)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
field=first_field,
value=body_to_process if body_to_process is not None else _NOT_SPECIFIED,
values=values,
loc=loc,
)
return {first_field.name: v_}, errors_
for field in body_fields:
loc = ("body", field.alias)
value: Optional[Any] = None
value: Optional[Any] = _NOT_SPECIFIED
if body_to_process is not None:
try:
value = body_to_process.get(field.alias)
value = _get_multidict_value(field, values=body_to_process)
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))
Expand Down
47 changes: 47 additions & 0 deletions tests/test_forms_nullable_param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Optional
from uuid import UUID, uuid4

import pytest
from fastapi import FastAPI, Form
from fastapi.testclient import TestClient
from typing_extensions import Annotated

app = FastAPI()

default_uuid = uuid4()


@app.post("/form-optional/")
def post_form_optional(
test_id: Annotated[Optional[UUID], Form(alias="testId")] = default_uuid,
) -> Optional[UUID]:
return test_id


@app.post("/form-required/")
def post_form_required(
test_id: Annotated[Optional[UUID], Form(alias="testId")],
) -> Optional[UUID]:
return test_id


client = TestClient(app)


def test_unspecified_optional() -> None:
response = client.post("/form-optional/", data={})
assert response.status_code == 200, response.text
assert response.json() == str(default_uuid)


def test_unspecified_required() -> None:
response = client.post("/form-required/", data={})
assert response.status_code == 422, response.text


@pytest.mark.parametrize("url", ["/form-optional/", "/form-required/"])
@pytest.mark.parametrize("test_id", [None, str(uuid4())])
def test_specified(url: str, test_id: Optional[str]) -> None:
response = client.post(url, data={"testId": test_id})
assert response.status_code == 200, response.text
assert response.json() == test_id
4 changes: 2 additions & 2 deletions tests/test_forms_single_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ def test_no_data():
"type": "missing",
"loc": ["body", "username"],
"msg": "Field required",
"input": {"tags": ["foo", "bar"], "with": "nothing"},
"input": {"age": None, "tags": ["foo", "bar"], "with": "nothing"},
},
{
"type": "missing",
"loc": ["body", "lastname"],
"msg": "Field required",
"input": {"tags": ["foo", "bar"], "with": "nothing"},
"input": {"age": None, "tags": ["foo", "bar"], "with": "nothing"},
},
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_cookie_param_model_invalid(client: TestClient):
"type": "missing",
"loc": ["cookie", "session_id"],
"msg": "Field required",
"input": {},
"input": {"fatebook_tracker": None, "googall_tracker": None},
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_cookie_param_model_invalid(client: TestClient):
"type": "missing",
"loc": ["cookie", "session_id"],
"msg": "Field required",
"input": {},
"input": {"fatebook_tracker": None, "googall_tracker": None},
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def test_header_param_model_invalid(client: TestClient):
"accept-encoding": "gzip, deflate",
"connection": "keep-alive",
"user-agent": "testclient",
"if_modified_since": None,
"traceparent": None,
},
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,12 @@ def test_header_param_model_invalid(client: TestClient):
"type": "missing",
"loc": ["header", "save_data"],
"msg": "Field required",
"input": {"x_tag": [], "host": "testserver"},
"input": {
"x_tag": [],
"host": "testserver",
"if_modified_since": None,
"traceparent": None,
},
}
)
| IsDict(
Expand Down