Skip to content

Commit 833e338

Browse files
authored
Fix four bugs in StackdriverTaskHandler (#13784)
1 parent d65376c commit 833e338

File tree

5 files changed

+95
-45
lines changed

5 files changed

+95
-45
lines changed

airflow/providers/google/cloud/log/stackdriver_task_handler.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(
9999
self.resource: Resource = resource
100100
self.labels: Optional[Dict[str, str]] = labels
101101
self.task_instance_labels: Optional[Dict[str, str]] = {}
102+
self.task_instance_hostname = 'default-hostname'
102103

103104
@cached_property
104105
def _client(self) -> gcp_logging.Client:
@@ -146,10 +147,11 @@ def set_context(self, task_instance: TaskInstance) -> None:
146147
:type task_instance: :class:`airflow.models.TaskInstance`
147148
"""
148149
self.task_instance_labels = self._task_instance_to_labels(task_instance)
150+
self.task_instance_hostname = task_instance.hostname
149151

150152
def read(
151153
self, task_instance: TaskInstance, try_number: Optional[int] = None, metadata: Optional[Dict] = None
152-
) -> Tuple[List[str], List[Dict]]:
154+
) -> Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]]:
153155
"""
154156
Read logs of given task instance from Stackdriver logging.
155157
@@ -160,12 +162,14 @@ def read(
160162
:type try_number: Optional[int]
161163
:param metadata: log metadata. It is used for steaming log reading and auto-tailing.
162164
:type metadata: Dict
163-
:return: a tuple of list of logs and list of metadata
164-
:rtype: Tuple[List[str], List[Dict]]
165+
:return: a tuple of (
166+
list of (one element tuple with two element tuple - hostname and logs)
167+
and list of metadata)
168+
:rtype: Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]]
165169
"""
166170
if try_number is not None and try_number < 1:
167-
logs = [f"Error fetching the logs. Try number {try_number} is invalid."]
168-
return logs, [{"end_of_log": "true"}]
171+
logs = f"Error fetching the logs. Try number {try_number} is invalid."
172+
return [((self.task_instance_hostname, logs),)], [{"end_of_log": "true"}]
169173

170174
if not metadata:
171175
metadata = {}
@@ -188,7 +192,7 @@ def read(
188192
if next_page_token:
189193
new_metadata['next_page_token'] = next_page_token
190194

191-
return [messages], [new_metadata]
195+
return [((self.task_instance_hostname, messages),)], [new_metadata]
192196

193197
def _prepare_log_filter(self, ti_labels: Dict[str, str]) -> str:
194198
"""
@@ -252,6 +256,8 @@ def _read_logs(
252256
log_filter=log_filter, page_token=next_page_token
253257
)
254258
messages.append(new_messages)
259+
if not messages:
260+
break
255261

256262
end_of_log = True
257263
next_page_token = None
@@ -271,7 +277,9 @@ def _read_single_logs_page(self, log_filter: str, page_token: Optional[str] = No
271277
:return: Downloaded logs and next page token
272278
:rtype: Tuple[str, str]
273279
"""
274-
entries = self._client.list_entries(filter_=log_filter, page_token=page_token)
280+
entries = self._client.list_entries(
281+
filter_=log_filter, page_token=page_token, order_by='timestamp asc', page_size=1000
282+
)
275283
page = next(entries.pages)
276284
next_page_token = entries.next_page_token
277285
messages = []
@@ -331,3 +339,6 @@ def get_external_log_url(self, task_instance: TaskInstance, try_number: int) ->
331339

332340
url = f"{self.LOG_VIEWER_BASE_URL}?{urlencode(url_query_string)}"
333341
return url
342+
343+
def close(self) -> None:
344+
self._transport.flush()

airflow/utils/log/log_reader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717

1818
import logging
19-
from typing import Any, Dict, Iterator, List, Optional, Tuple
19+
from typing import Dict, Iterator, List, Optional, Tuple
2020

2121
from cached_property import cached_property
2222

@@ -31,7 +31,7 @@ class TaskLogReader:
3131

3232
def read_log_chunks(
3333
self, ti: TaskInstance, try_number: Optional[int], metadata
34-
) -> Tuple[List[str], Dict[str, Any]]:
34+
) -> Tuple[List[Tuple[Tuple[str, str]]], Dict[str, str]]:
3535
"""
3636
Reads chunks of Task Instance logs.
3737
@@ -42,7 +42,7 @@ def read_log_chunks(
4242
:type try_number: Optional[int]
4343
:param metadata: A dictionary containing information about how to read the task log
4444
:type metadata: dict
45-
:rtype: Tuple[List[str], Dict[str, Any]]
45+
:rtype: Tuple[List[Tuple[Tuple[str, str]]], Dict[str, str]]
4646
4747
The following is an example of how to use this method to read log:
4848

tests/cli/commands/test_info_command.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import contextlib
1919
import importlib
2020
import io
21+
import logging
2122
import os
2223
import unittest
2324
from unittest import mock
@@ -129,6 +130,8 @@ def test_should_read_logging_configuration(self):
129130
assert "stackdriver" in text
130131

131132
def tearDown(self) -> None:
133+
for handler_ref in logging._handlerList[:]:
134+
logging._removeHandlerRef(handler_ref)
132135
importlib.reload(airflow_local_settings)
133136
configure_logging()
134137

tests/providers/google/cloud/log/test_stackdriver_task_handler.py

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,21 @@ def _create_list_response(messages, token):
3535
return mock.MagicMock(pages=(n for n in [page]), next_page_token=token)
3636

3737

38+
def _remove_stackdriver_handlers():
39+
for handler_ref in reversed(logging._handlerList[:]):
40+
handler = handler_ref()
41+
if not isinstance(handler, StackdriverTaskHandler):
42+
continue
43+
logging._removeHandlerRef(handler_ref)
44+
del handler
45+
46+
3847
class TestStackdriverLoggingHandlerStandalone(unittest.TestCase):
3948
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
4049
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
4150
def test_should_pass_message_to_client(self, mock_client, mock_get_creds_and_project_id):
51+
self.addCleanup(_remove_stackdriver_handlers)
52+
4253
mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
4354

4455
transport_type = mock.MagicMock()
@@ -69,6 +80,7 @@ def setUp(self) -> None:
6980
self.ti.try_number = 1
7081
self.ti.state = State.RUNNING
7182
self.addCleanup(self.dag.clear)
83+
self.addCleanup(_remove_stackdriver_handlers)
7284

7385
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
7486
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
@@ -128,14 +140,18 @@ def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_proj
128140

129141
logs, metadata = self.stackdriver_task_handler.read(self.ti)
130142
mock_client.return_value.list_entries.assert_called_once_with(
131-
filter_='resource.type="global"\n'
132-
'logName="projects/asf-project/logs/airflow"\n'
133-
'labels.task_id="task_for_testing_file_log_handler"\n'
134-
'labels.dag_id="dag_for_testing_file_task_handler"\n'
135-
'labels.execution_date="2016-01-01T00:00:00+00:00"',
143+
filter_=(
144+
'resource.type="global"\n'
145+
'logName="projects/asf-project/logs/airflow"\n'
146+
'labels.task_id="task_for_testing_file_log_handler"\n'
147+
'labels.dag_id="dag_for_testing_file_task_handler"\n'
148+
'labels.execution_date="2016-01-01T00:00:00+00:00"'
149+
),
150+
order_by='timestamp asc',
151+
page_size=1000,
136152
page_token=None,
137153
)
138-
assert ['MSG1\nMSG2'] == logs
154+
assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
139155
assert [{'end_of_log': True}] == metadata
140156

141157
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -149,14 +165,18 @@ def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_
149165
self.ti.task_id = "K\"OT"
150166
logs, metadata = self.stackdriver_task_handler.read(self.ti)
151167
mock_client.return_value.list_entries.assert_called_once_with(
152-
filter_='resource.type="global"\n'
153-
'logName="projects/asf-project/logs/airflow"\n'
154-
'labels.task_id="K\\"OT"\n'
155-
'labels.dag_id="dag_for_testing_file_task_handler"\n'
156-
'labels.execution_date="2016-01-01T00:00:00+00:00"',
168+
filter_=(
169+
'resource.type="global"\n'
170+
'logName="projects/asf-project/logs/airflow"\n'
171+
'labels.task_id="K\\"OT"\n'
172+
'labels.dag_id="dag_for_testing_file_task_handler"\n'
173+
'labels.execution_date="2016-01-01T00:00:00+00:00"'
174+
),
175+
order_by='timestamp asc',
176+
page_size=1000,
157177
page_token=None,
158178
)
159-
assert ['MSG1\nMSG2'] == logs
179+
assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
160180
assert [{'end_of_log': True}] == metadata
161181

162182
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -170,15 +190,19 @@ def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_p
170190

171191
logs, metadata = self.stackdriver_task_handler.read(self.ti, 3)
172192
mock_client.return_value.list_entries.assert_called_once_with(
173-
filter_='resource.type="global"\n'
174-
'logName="projects/asf-project/logs/airflow"\n'
175-
'labels.task_id="task_for_testing_file_log_handler"\n'
176-
'labels.dag_id="dag_for_testing_file_task_handler"\n'
177-
'labels.execution_date="2016-01-01T00:00:00+00:00"\n'
178-
'labels.try_number="3"',
193+
filter_=(
194+
'resource.type="global"\n'
195+
'logName="projects/asf-project/logs/airflow"\n'
196+
'labels.task_id="task_for_testing_file_log_handler"\n'
197+
'labels.dag_id="dag_for_testing_file_task_handler"\n'
198+
'labels.execution_date="2016-01-01T00:00:00+00:00"\n'
199+
'labels.try_number="3"'
200+
),
201+
order_by='timestamp asc',
202+
page_size=1000,
179203
page_token=None,
180204
)
181-
assert ['MSG1\nMSG2'] == logs
205+
assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
182206
assert [{'end_of_log': True}] == metadata
183207

184208
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -190,14 +214,18 @@ def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_
190214
]
191215
mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
192216
logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3)
193-
mock_client.return_value.list_entries.assert_called_once_with(filter_=mock.ANY, page_token=None)
194-
assert ['MSG1\nMSG2'] == logs
217+
mock_client.return_value.list_entries.assert_called_once_with(
218+
filter_=mock.ANY, order_by='timestamp asc', page_size=1000, page_token=None
219+
)
220+
assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
195221
assert [{'end_of_log': False, 'next_page_token': 'TOKEN1'}] == metadata1
196222

197223
mock_client.return_value.list_entries.return_value.next_page_token = None
198224
logs, metadata2 = self.stackdriver_task_handler.read(self.ti, 3, metadata1[0])
199-
mock_client.return_value.list_entries.assert_called_with(filter_=mock.ANY, page_token="TOKEN1")
200-
assert ['MSG3\nMSG4'] == logs
225+
mock_client.return_value.list_entries.assert_called_with(
226+
filter_=mock.ANY, order_by='timestamp asc', page_size=1000, page_token="TOKEN1"
227+
)
228+
assert [(('default-hostname', 'MSG3\nMSG4'),)] == logs
201229
assert [{'end_of_log': True}] == metadata2
202230

203231
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -211,7 +239,7 @@ def test_should_read_logs_with_download(self, mock_client, mock_get_creds_and_pr
211239

212240
logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3, {'download_logs': True})
213241

214-
assert ['MSG1\nMSG2\nMSG3\nMSG4'] == logs
242+
assert [(('default-hostname', 'MSG1\nMSG2\nMSG3\nMSG4'),)] == logs
215243
assert [{'end_of_log': True}] == metadata1
216244

217245
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -240,17 +268,21 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred
240268

241269
logs, metadata = self.stackdriver_task_handler.read(self.ti)
242270
mock_client.return_value.list_entries.assert_called_once_with(
243-
filter_='resource.type="cloud_composer_environment"\n'
244-
'logName="projects/asf-project/logs/airflow"\n'
245-
'resource.labels."environment.name"="test-instancce"\n'
246-
'resource.labels.location="europpe-west-3"\n'
247-
'resource.labels.project_id="asf-project"\n'
248-
'labels.task_id="task_for_testing_file_log_handler"\n'
249-
'labels.dag_id="dag_for_testing_file_task_handler"\n'
250-
'labels.execution_date="2016-01-01T00:00:00+00:00"',
271+
filter_=(
272+
'resource.type="cloud_composer_environment"\n'
273+
'logName="projects/asf-project/logs/airflow"\n'
274+
'resource.labels."environment.name"="test-instancce"\n'
275+
'resource.labels.location="europpe-west-3"\n'
276+
'resource.labels.project_id="asf-project"\n'
277+
'labels.task_id="task_for_testing_file_log_handler"\n'
278+
'labels.dag_id="dag_for_testing_file_task_handler"\n'
279+
'labels.execution_date="2016-01-01T00:00:00+00:00"'
280+
),
281+
order_by='timestamp asc',
282+
page_size=1000,
251283
page_token=None,
252284
)
253-
assert ['TEXT\nTEXT'] == logs
285+
assert [(('default-hostname', 'TEXT\nTEXT'),)] == logs
254286
assert [{'end_of_log': True}] == metadata
255287

256288
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')

tests/providers/google/cloud/log/test_stackdriver_task_handler_system.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_should_support_key_auth(self, session):
6262
assert 0 == subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()
6363
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()
6464

65-
self.assert_remote_logs("INFO - Task exited with return code 0", ti)
65+
self.assert_remote_logs("terminated with exit code 0", ti)
6666

6767
@provide_session
6868
def test_should_support_adc(self, session):
@@ -78,7 +78,7 @@ def test_should_support_adc(self, session):
7878
assert 0 == subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()
7979
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()
8080

81-
self.assert_remote_logs("INFO - Task exited with return code 0", ti)
81+
self.assert_remote_logs("terminated with exit code 0", ti)
8282

8383
def assert_remote_logs(self, expected_message, ti):
8484
with provide_gcp_context(GCP_STACKDRIVER), conf_vars(
@@ -94,4 +94,8 @@ def assert_remote_logs(self, expected_message, ti):
9494

9595
task_log_reader = TaskLogReader()
9696
logs = "\n".join(task_log_reader.read_log_stream(ti, try_number=None, metadata={}))
97+
# Preview content
98+
print("=" * 80)
99+
print(logs)
100+
print("=" * 80)
97101
assert expected_message in logs

0 commit comments

Comments
 (0)