Skip to content

Commit ef2da6f

Browse files
authored
Port dag.test to Task SDK (#50300)
closes #45549 Key changes: - Moves `dag.test` implementation to Task SDK, leveraging the existing in-process execution infrastructure - Adds `JWTBearerTIPathDep` for proper task instance path validation - Updates `InProcessExecutionAPI` to support task instance validation - Removes legacy `dag.test` implementation from DAG class The changes ensure that `dag.test` uses the same execution path as regular task execution.
1 parent 2622db3 commit ef2da6f

14 files changed

Lines changed: 539 additions & 375 deletions

File tree

airflow-core/src/airflow/api_fastapi/execution_api/app.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,11 @@ class InProcessExecutionAPI:
225225
def app(self):
226226
if not self._app:
227227
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
228-
from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTRefresherDep
228+
from airflow.api_fastapi.execution_api.deps import (
229+
JWTBearerDep,
230+
JWTBearerTIPathDep,
231+
JWTRefresherDep,
232+
)
229233
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
230234
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
231235
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
@@ -235,6 +239,7 @@ def app(self):
235239
async def always_allow(): ...
236240

237241
self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow
242+
self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow
238243
self._app.dependency_overrides[JWTRefresherDep.dependency] = always_allow
239244
self._app.dependency_overrides[has_connection_access] = always_allow
240245
self._app.dependency_overrides[has_variable_access] = always_allow

airflow-core/src/airflow/api_fastapi/execution_api/deps.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ async def __call__( # type: ignore[override]
9696

9797
JWTBearerDep: TIToken = Depends(JWTBearer())
9898

99+
# This checks that the UUID in the url matches the one in the token for us.
100+
JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
101+
99102

100103
class JWTReissuer:
101104
"""Re-issue JWTs to requests when they are about to run out."""

airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import structlog
2727
from cadwyn import VersionedAPIRouter
28-
from fastapi import Body, Depends, HTTPException, Query, status
28+
from fastapi import Body, HTTPException, Query, status
2929
from pydantic import JsonValue
3030
from sqlalchemy import func, or_, tuple_, update
3131
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
@@ -50,7 +50,7 @@
5050
TISuccessStatePayload,
5151
TITerminalStatePayload,
5252
)
53-
from airflow.api_fastapi.execution_api.deps import JWTBearer
53+
from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
5454
from airflow.models.dagbag import DagBag
5555
from airflow.models.dagrun import DagRun as DR
5656
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
@@ -70,7 +70,7 @@
7070
ti_id_router = VersionedAPIRouter(
7171
dependencies=[
7272
# This checks that the UUID in the url matches the one in the token for us.
73-
Depends(JWTBearer(path_param_name="task_instance_id")),
73+
JWTBearerTIPathDep
7474
]
7575
)
7676

airflow-core/src/airflow/cli/commands/dag_command.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,6 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
644644
run_conf=run_conf,
645645
use_executor=use_executor,
646646
mark_success_pattern=mark_success_pattern,
647-
session=session,
648647
)
649648
show_dagrun = args.show_dagrun
650649
imgcat = args.imgcat_dagrun

airflow-core/src/airflow/cli/commands/task_command.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string
3434
from airflow.exceptions import DagRunNotFound, TaskDeferred, TaskInstanceNotFound
3535
from airflow.models import TaskInstance
36-
from airflow.models.dag import DAG, _run_inline_trigger
36+
from airflow.models.dag import DAG
3737
from airflow.models.dagrun import DagRun
38+
from airflow.sdk.definitions.dag import _run_inline_trigger
3839
from airflow.sdk.definitions.param import ParamsDict
3940
from airflow.sdk.execution_time.secrets_masker import RedactedIO
4041
from airflow.ti_deps.dep_context import DepContext

airflow-core/src/airflow/dag_processing/processor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,11 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
161161

162162
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
163163
# TODO:We need a proper context object!
164-
context: Context = {} # type: ignore[assignment]
164+
context: Context = { # type: ignore[assignment]
165+
"dag": dag,
166+
"run_id": request.run_id,
167+
"reason": request.msg,
168+
}
165169

166170
for callback in callbacks:
167171
log.info(

0 commit comments

Comments
 (0)