
Commit 875387a

Refactor unneeded jumps in providers (#33833)
1 parent 0e1c106 commit 875387a

File tree

17 files changed (+108, -142 lines)

airflow/providers/amazon/aws/hooks/datasync.py

Lines changed: 9 additions & 16 deletions

@@ -301,25 +301,18 @@ def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int =
         if not task_execution_arn:
             raise AirflowBadRequest("task_execution_arn not specified")

-        status = None
-        iterations = max_iterations
-        while status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
+        for _ in range(max_iterations):
             task_execution = self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)
             status = task_execution["Status"]
             self.log.info("status=%s", status)
-            iterations -= 1
-            if status in self.TASK_EXECUTION_FAILURE_STATES:
-                break
             if status in self.TASK_EXECUTION_SUCCESS_STATES:
-                break
-            if iterations <= 0:
-                break
+                return True
+            elif status in self.TASK_EXECUTION_FAILURE_STATES:
+                return False
+            elif status is None or status in self.TASK_EXECUTION_INTERMEDIATE_STATES:
+                time.sleep(self.wait_interval_seconds)
+            else:
+                raise AirflowException(f"Unknown status: {status}")  # Should never happen
             time.sleep(self.wait_interval_seconds)
-
-        if status in self.TASK_EXECUTION_SUCCESS_STATES:
-            return True
-        if status in self.TASK_EXECUTION_FAILURE_STATES:
-            return False
-        if iterations <= 0:
+        else:
             raise AirflowTaskTimeout("Max iterations exceeded!")
-        raise AirflowException(f"Unknown status: {status}")  # Should never happen
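Note: the replacement relies on Python's for/else: the else: block runs only when the loop exhausts max_iterations without a break (a return simply leaves the function first and skips it). A minimal, self-contained sketch of the pattern, where poll(), the status constants, and wait_for() are illustrative stand-ins rather than the hook's real API:

import time

SUCCESS, FAILURE, RUNNING = "SUCCESS", "FAILURE", "RUNNING"

def wait_for(poll, max_iterations: int = 5, interval: float = 0.1) -> bool:
    for _ in range(max_iterations):
        status = poll()
        if status == SUCCESS:
            return True        # leaves the function; the else clause never runs
        elif status == FAILURE:
            return False
        elif status == RUNNING:
            time.sleep(interval)
        else:
            raise RuntimeError(f"Unknown status: {status}")
    else:
        # Reached only when every iteration completed without a break or return.
        raise TimeoutError("Max iterations exceeded!")

statuses = iter([RUNNING, SUCCESS])
print(wait_for(lambda: next(statuses)))  # True

Because the loop body never uses break, the else: is optional here; dedenting the final raise to sit after the loop would behave the same.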

airflow/providers/amazon/aws/hooks/sagemaker.py

Lines changed: 6 additions & 6 deletions

@@ -252,12 +252,12 @@ def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Ge
         ]
         events: list[Any | None] = []
         for event_stream in event_iters:
-            if not event_stream:
-                events.append(None)
-                continue
-            try:
-                events.append(next(event_stream))
-            except StopIteration:
+            if event_stream:
+                try:
+                    events.append(next(event_stream))
+                except StopIteration:
+                    events.append(None)
+            else:
                 events.append(None)

         while any(events):
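Note: the guard clause (if not event_stream: ... continue) is folded into a plain if/else, with the StopIteration fallback nested under the positive branch. A minimal sketch of the same idea, advancing several iterators in lockstep and padding missing or exhausted slots with None, using generic iterators in place of CloudWatch log streams:

from typing import Any

def next_batch(iterators: list) -> list[Any | None]:
    events: list[Any | None] = []
    for it in iterators:
        if it:
            try:
                events.append(next(it))
            except StopIteration:
                events.append(None)  # iterator exhausted
        else:
            events.append(None)      # nothing to read in this slot
    return events

print(next_batch([iter([1, 2]), None, iter(["a"])]))  # [1, None, 'a']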

airflow/providers/amazon/aws/sensors/sqs.py

Lines changed: 4 additions & 4 deletions

@@ -204,12 +204,12 @@ def poke(self, context: Context):

             if "Successful" not in response:
                 raise AirflowException(f"Delete SQS Messages failed {response} for messages {messages}")
-        if not message_batch:
+        if message_batch:
+            context["ti"].xcom_push(key="messages", value=message_batch)
+            return True
+        else:
             return False

-        context["ti"].xcom_push(key="messages", value=message_batch)
-        return True
-
     @deprecated(reason="use `hook` property instead.")
     def get_hook(self) -> SqsHook:
         """Create and return an SqsHook."""

airflow/providers/amazon/aws/utils/sqs.py

Lines changed: 4 additions & 8 deletions

@@ -79,13 +79,9 @@ def filter_messages_jsonpath(messages, message_filtering_match_values, message_f
         # Body is a string, deserialize to an object and then parse
         body = json.loads(body)
         results = jsonpath_expr.find(body)
-        if not results:
-            continue
-        if message_filtering_match_values is None:
+        if results and (
+            message_filtering_match_values is None
+            or any(result.value in message_filtering_match_values for result in results)
+        ):
             filtered_messages.append(message)
-            continue
-        for result in results:
-            if result.value in message_filtering_match_values:
-                filtered_messages.append(message)
-                break
     return filtered_messages
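Note: the two continue statements and the inner for/break are collapsed into one compound condition built around any(). A minimal sketch of the same filtering shape, where plain dicts stand in for SQS messages and JSONPath matches:

def filter_messages(messages, match_values=None, key="state"):
    filtered = []
    for message in messages:
        results = [value for k, value in message.items() if k == key]
        if results and (
            match_values is None
            or any(value in match_values for value in results)
        ):
            filtered.append(message)
    return filtered

msgs = [{"state": "ok"}, {"state": "bad"}, {"other": 1}]
print(filter_messages(msgs, match_values={"ok"}))  # [{'state': 'ok'}]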

airflow/providers/cncf/kubernetes/utils/delete_from.py

Lines changed: 1 addition & 3 deletions

@@ -81,9 +81,7 @@ def delete_from_yaml(
     **kwargs,
 ):
     for yml_document in yaml_objects:
-        if yml_document is None:
-            continue
-        else:
+        if yml_document is not None:
             delete_from_dict(
                 k8s_client=k8s_client,
                 data=yml_document,

airflow/providers/google/cloud/hooks/bigquery.py

Lines changed: 17 additions & 19 deletions

@@ -2206,27 +2206,25 @@ def run_query(
             if param_name == "schemaUpdateOptions" and param:
                 self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options)

-            if param_name != "destinationTable":
-                continue
-
-            for key in ["projectId", "datasetId", "tableId"]:
-                if key not in configuration["query"]["destinationTable"]:
-                    raise ValueError(
-                        "Not correct 'destinationTable' in "
-                        "api_resource_configs. 'destinationTable' "
-                        "must be a dict with {'projectId':'', "
-                        "'datasetId':'', 'tableId':''}"
+            if param_name == "destinationTable":
+                for key in ["projectId", "datasetId", "tableId"]:
+                    if key not in configuration["query"]["destinationTable"]:
+                        raise ValueError(
+                            "Not correct 'destinationTable' in "
+                            "api_resource_configs. 'destinationTable' "
+                            "must be a dict with {'projectId':'', "
+                            "'datasetId':'', 'tableId':''}"
+                        )
+                else:
+                    configuration["query"].update(
+                        {
+                            "allowLargeResults": allow_large_results,
+                            "flattenResults": flatten_results,
+                            "writeDisposition": write_disposition,
+                            "createDisposition": create_disposition,
+                        }
                     )

-            configuration["query"].update(
-                {
-                    "allowLargeResults": allow_large_results,
-                    "flattenResults": flatten_results,
-                    "writeDisposition": write_disposition,
-                    "createDisposition": create_disposition,
-                }
-            )
-
         if (
             "useLegacySql" in configuration["query"]
             and configuration["query"]["useLegacySql"]
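Note: the added else: belongs to the inner for loop (for/else), not to the if: the update() runs only after the validation loop has checked every key without raising. A minimal sketch of that validate-then-apply shape, where apply_destination() and the plain dict are stand-ins for the BigQuery job configuration:

def apply_destination(config: dict, destination: dict) -> dict:
    for key in ["projectId", "datasetId", "tableId"]:
        if key not in destination:
            raise ValueError(f"'destinationTable' is missing required key {key!r}")
    else:
        # Runs only when the loop finished without a break (a raise skips it).
        config["destinationTable"] = destination
    return config

print(apply_destination({}, {"projectId": "p", "datasetId": "d", "tableId": "t"}))

Since the loop contains no break, the else: could also be dropped and the update dedented; the behaviour is identical either way.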

airflow/providers/google/cloud/hooks/datafusion.py

Lines changed: 3 additions & 3 deletions

@@ -371,12 +371,12 @@ def delete_pipeline(
                 self._check_response_status_and_data(
                     response, f"Deleting a pipeline failed with code {response.status}: {response.data}"
                 )
-                if response.status == 200:
-                    break
             except ConflictException as exc:
                 self.log.info(exc)
                 sleep(time_to_wait)
-                continue
+            else:
+                if response.status == 200:
+                    break

     def list_pipelines(
         self,
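Note: the success check moves out of the try body into a try/except/else; the else: branch runs only when the try block raised nothing, keeping the retry handling and the success check separated. A minimal sketch of the shape, where ConflictError, fake_call(), and delete_with_retry() are illustrative stand-ins, not the hook's real API:

import time

class ConflictError(Exception):
    pass

def delete_with_retry(call, attempts: int = 3, wait: float = 0.1) -> None:
    for _ in range(attempts):
        try:
            response = call()
        except ConflictError as exc:
            print("conflict, retrying:", exc)
            time.sleep(wait)
        else:
            # Runs only when call() raised nothing.
            if response == 200:
                break

outcomes = iter([ConflictError("pipeline busy"), 200])

def fake_call() -> int:
    outcome = next(outcomes)
    if isinstance(outcome, Exception):
        raise outcome
    return outcome

delete_with_retry(fake_call)  # one conflict, then a 200 response stops the loop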

airflow/providers/google/cloud/hooks/gcs.py

Lines changed: 4 additions & 10 deletions

@@ -361,7 +361,6 @@ def download(
                     # Wait with exponential backoff scheme before retrying.
                     timeout_seconds = 2 ** (num_file_attempts - 1)
                     time.sleep(timeout_seconds)
-                    continue

     def download_as_byte_array(
         self,

@@ -508,28 +507,23 @@ def _call_with_retry(f: Callable[[], None]) -> None:

             :param f: Callable that should be retried.
             """
-            num_file_attempts = 0
-
-            while num_file_attempts < num_max_attempts:
+            for attempt in range(1, 1 + num_max_attempts):
                 try:
-                    num_file_attempts += 1
                     f()
-
                 except GoogleCloudError as e:
-                    if num_file_attempts == num_max_attempts:
+                    if attempt == num_max_attempts:
                         self.log.error(
                             "Upload attempt of object: %s from %s has failed. Attempt: %s, max %s.",
                             object_name,
                             object_name,
-                            num_file_attempts,
+                            attempt,
                             num_max_attempts,
                         )
                         raise e

                     # Wait with exponential backoff scheme before retrying.
-                    timeout_seconds = 2 ** (num_file_attempts - 1)
+                    timeout_seconds = 2 ** (attempt - 1)
                     time.sleep(timeout_seconds)
-                    continue

         client = self.get_conn()
         bucket = client.bucket(bucket_name, user_project=user_project)
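Note: the hand-maintained counter (num_file_attempts = 0 / += 1) becomes a 1-based range(), and the trailing continue, already the last statement in the loop body, is dropped. A minimal sketch of the attempt counter with exponential backoff; call_with_retry(), flaky(), and RuntimeError are stand-ins for the provider's helper and GoogleCloudError, and the sketch returns as soon as the call succeeds:

import time

def call_with_retry(f, num_max_attempts: int = 3) -> None:
    for attempt in range(1, 1 + num_max_attempts):
        try:
            f()
        except RuntimeError:
            if attempt == num_max_attempts:
                raise  # out of attempts, surface the last error
            # Back off 1s, 2s, 4s, ... before the next attempt.
            time.sleep(2 ** (attempt - 1))
        else:
            return  # success, stop retrying

calls = {"n": 0}

def flaky() -> None:
    calls["n"] += 1
    if calls["n"] < 2:
        raise RuntimeError("transient failure")

call_with_retry(flaky)  # first attempt fails, second succeeds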

airflow/providers/google/cloud/log/gcs_task_handler.py

Lines changed: 1 addition & 3 deletions

@@ -243,9 +243,7 @@ def gcs_write(self, log, remote_log_location) -> bool:
                 old_log = blob.download_as_bytes().decode()
                 log = "\n".join([old_log, log]) if old_log else log
             except Exception as e:
-                if self.no_log_found(e):
-                    pass
-                else:
+                if not self.no_log_found(e):
                     log += self._add_message(
                         f"Error checking for previous log; if exists, may be overwritten: {e}"
                     )

airflow/providers/google/cloud/operators/compute.py

Lines changed: 21 additions & 24 deletions

@@ -174,14 +174,13 @@ def __init__(
     def check_body_fields(self) -> None:
         required_params = ["machine_type", "disks", "network_interfaces"]
         for param in required_params:
-            if param in self.body:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least {readable_param} for the new operator "
-                f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+                    f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )

     def _validate_inputs(self) -> None:
         super()._validate_inputs()

@@ -915,14 +914,13 @@ def __init__(
     def check_body_fields(self) -> None:
         required_params = ["machine_type", "disks", "network_interfaces"]
         for param in required_params:
-            if param in self.body["properties"]:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least {readable_param} for the new operator "
-                f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body["properties"]:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+                    f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )

     def _validate_all_body_fields(self) -> None:
         if self._field_validator:

@@ -1500,14 +1498,13 @@ def __init__(
     def check_body_fields(self) -> None:
         required_params = ["base_instance_name", "target_size", "instance_template"]
        for param in required_params:
-            if param in self.body:
-                continue
-            readable_param = param.replace("_", " ")
-            raise AirflowException(
-                f"The body '{self.body}' should contain at least {readable_param} for the new operator "
-                f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
-                f"for more details about body fields description."
-            )
+            if param not in self.body:
+                readable_param = param.replace("_", " ")
+                raise AirflowException(
+                    f"The body '{self.body}' should contain at least {readable_param} for the new operator "
+                    f"in the '{param}' field. Check (google.cloud.compute_v1.types.Instance) "
+                    f"for more details about body fields description."
+                )

     def _validate_all_body_fields(self) -> None:
         if self._field_validator:
