-
-
Notifications
You must be signed in to change notification settings - Fork 8.5k
🐛 Fix support for nullable form fields #12502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -87,6 +87,9 @@ | |
| "pip install python-multipart\n" | ||
| ) | ||
|
|
||
| # Sentinel value for unspecified fields | ||
| _NOT_SPECIFIED = object() | ||
|
|
||
|
|
||
| def ensure_multipart_is_installed() -> None: | ||
| try: | ||
|
|
@@ -690,7 +693,7 @@ async def solve_dependencies( | |
| def _validate_value_with_model_field( | ||
| *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] | ||
| ) -> Tuple[Any, List[Any]]: | ||
| if value is None: | ||
| if value is _NOT_SPECIFIED: | ||
| if field.required: | ||
| return None, [get_missing_field_error(loc=loc)] | ||
| else: | ||
|
|
@@ -712,20 +715,24 @@ def _get_multidict_value( | |
| if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): | ||
| value = values.getlist(alias) | ||
| else: | ||
| value = values.get(alias, None) | ||
| value = values.get(alias, _NOT_SPECIFIED) | ||
| if ( | ||
| value is None | ||
| or ( | ||
| isinstance(field.field_info, params.Form) | ||
| and isinstance(value, str) # For type checks | ||
| and value == "" | ||
| ) | ||
| or (is_sequence_field(field) and len(value) == 0) | ||
| isinstance(field.field_info, params.Form) | ||
| and isinstance(value, str) # For type checks | ||
| and value == "" | ||
|
Comment on lines
+720
to
+722
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please clarify why you removed the |
||
| ): | ||
| # Empty strings in a form can be a representation of None values | ||
| _, error = field.validate(None, {}, loc=()) | ||
| # If None is an accepted value for this field, use that | ||
| if error is None: | ||
| value = None | ||
|
Comment on lines
+724
to
+728
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't really like this part.. Can we try getting this information from field annotation? |
||
|
|
||
| if value == _NOT_SPECIFIED or (is_sequence_field(field) and len(value) == 0): | ||
| if field.required: | ||
| return | ||
| return _NOT_SPECIFIED | ||
| else: | ||
| return deepcopy(field.default) | ||
|
|
||
| return value | ||
|
|
||
|
|
||
|
|
@@ -763,7 +770,7 @@ def request_params_to_args( | |
| else field.name.replace("_", "-") | ||
| ) | ||
| value = _get_multidict_value(field, received_params, alias=alias) | ||
| if value is not None: | ||
| if value != _NOT_SPECIFIED: | ||
| params_to_process[field.name] = value | ||
| processed_keys.add(alias or field.alias) | ||
| processed_keys.add(field.name) | ||
|
|
@@ -830,6 +837,8 @@ async def _extract_form_body( | |
| first_field = body_fields[0] | ||
| first_field_info = first_field.field_info | ||
|
|
||
| processed_keys = set() | ||
|
|
||
| for field in body_fields: | ||
| value = _get_multidict_value(field, received_body) | ||
| if ( | ||
|
|
@@ -857,10 +866,13 @@ async def process_fn( | |
| for sub_value in value: | ||
| tg.start_soon(process_fn, sub_value.read) | ||
| value = serialize_sequence_value(field=field, value=results) | ||
| if value is not None: | ||
| if value != _NOT_SPECIFIED: | ||
| values[field.alias] = value | ||
| processed_keys.add(field.alias) | ||
| processed_keys.add(field.name) | ||
|
|
||
| for key, value in received_body.items(): | ||
| if key not in values: | ||
| if key not in processed_keys: | ||
| values[key] = value | ||
| return values | ||
|
|
||
|
|
@@ -888,15 +900,18 @@ async def request_body_to_args( | |
| if single_not_embedded_field: | ||
| loc: Tuple[str, ...] = ("body",) | ||
| v_, errors_ = _validate_value_with_model_field( | ||
| field=first_field, value=body_to_process, values=values, loc=loc | ||
| field=first_field, | ||
| value=body_to_process if body_to_process is not None else _NOT_SPECIFIED, | ||
| values=values, | ||
| loc=loc, | ||
| ) | ||
| return {first_field.name: v_}, errors_ | ||
| for field in body_fields: | ||
| loc = ("body", field.alias) | ||
| value: Optional[Any] = None | ||
| value: Optional[Any] = _NOT_SPECIFIED | ||
| if body_to_process is not None: | ||
| try: | ||
| value = body_to_process.get(field.alias) | ||
| value = _get_multidict_value(field, values=body_to_process) | ||
| # If the received body is a list, not a dict | ||
| except AttributeError: | ||
| errors.append(get_missing_field_error(loc)) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| from typing import Optional | ||
| from uuid import UUID, uuid4 | ||
|
|
||
| import pytest | ||
| from fastapi import FastAPI, Form | ||
| from fastapi.testclient import TestClient | ||
| from typing_extensions import Annotated | ||
|
|
||
| app = FastAPI() | ||
|
|
||
| default_uuid = uuid4() | ||
|
|
||
|
|
||
| @app.post("/form-optional/") | ||
| def post_form_optional( | ||
| test_id: Annotated[Optional[UUID], Form(alias="testId")] = default_uuid, | ||
| ) -> Optional[UUID]: | ||
| return test_id | ||
|
|
||
|
|
||
| @app.post("/form-required/") | ||
| def post_form_required( | ||
| test_id: Annotated[Optional[UUID], Form(alias="testId")], | ||
| ) -> Optional[UUID]: | ||
| return test_id | ||
|
|
||
|
|
||
| client = TestClient(app) | ||
|
|
||
|
|
||
| def test_unspecified_optional() -> None: | ||
| response = client.post("/form-optional/", data={}) | ||
| assert response.status_code == 200, response.text | ||
| assert response.json() == str(default_uuid) | ||
|
|
||
|
|
||
| def test_unspecified_required() -> None: | ||
| response = client.post("/form-required/", data={}) | ||
| assert response.status_code == 422, response.text | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("url", ["/form-optional/", "/form-required/"]) | ||
| @pytest.mark.parametrize("test_id", [None, str(uuid4())]) | ||
| def test_specified(url: str, test_id: Optional[str]) -> None: | ||
| response = client.post(url, data={"testId": test_id}) | ||
| assert response.status_code == 200, response.text | ||
| assert response.json() == test_id |
Uh oh!
There was an error while loading. Please reload this page.