
Commit 5d3a7ee

mik-laj and turbaszek authored
Allow multiple extra_packages in Dataflow (#8394)
Co-authored-by: Tomek Urbaszek <tomasz.urbaszek@polidea.com>
1 parent c34ba9a commit 5d3a7ee

File tree

3 files changed: +112 additions, -20 deletions


airflow/providers/google/cloud/hooks/dataflow.py

Lines changed: 17 additions & 8 deletions

@@ -681,14 +681,23 @@ def _build_cmd(variables: Dict, label_formatter: Callable, project_id: str) -> L
             "--runner=DataflowRunner",
             "--project={}".format(project_id),
         ]
-        if variables is not None:
-            for attr, value in variables.items():
-                if attr == 'labels':
-                    command += label_formatter(value)
-                elif value is None or value.__len__() < 1:
-                    command.append("--" + attr)
-                else:
-                    command.append("--" + attr + "=" + value)
+        if variables is None:
+            return command
+
+        # The logic of this method should be compatible with Apache Beam:
+        # https://github.com/apache/beam/blob/b56740f0e8cd80c2873412847d0b336837429fb9/sdks/python/
+        # apache_beam/options/pipeline_options.py#L230-L251
+        for attr, value in variables.items():
+            if attr == 'labels':
+                command += label_formatter(value)
+            elif value is None:
+                command.append(f"--{attr}")
+            elif isinstance(value, bool) and value:
+                command.append(f"--{attr}")
+            elif isinstance(value, list):
+                command.extend([f"--{attr}={v}" for v in value])
+            else:
+                command.append(f"--{attr}={value}")
         return command
 
     @_fallback_to_project_id_from_variables
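
For readers skimming the hunk above, the key change is the new isinstance(value, list) branch: a list value now expands into one repeated flag per element, which is what makes multiple extra packages possible. Below is a minimal standalone sketch of the same conversion logic; build_flags is a hypothetical helper (not part of the hook's API) and the labels special case is omitted:

from typing import Any, Dict, List

def build_flags(variables: Dict[str, Any]) -> List[str]:
    """Mirror the option-to-flag conversion shown in the hunk above."""
    command: List[str] = []
    for attr, value in variables.items():
        if value is None:
            command.append(f"--{attr}")  # bare flag without a value
        elif isinstance(value, bool) and value:
            command.append(f"--{attr}")  # True also becomes a bare flag
        elif isinstance(value, list):
            # new behavior: repeat the flag once per list element
            command.extend(f"--{attr}={v}" for v in value)
        else:
            # any other value is rendered with its textual representation
            command.append(f"--{attr}={value}")
    return command

print(build_flags({'extra-package': ['a.whl', 'b.whl'], 'streaming': True}))
# ['--extra-package=a.whl', '--extra-package=b.whl', '--streaming']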

airflow/providers/google/cloud/operators/dataflow.py

Lines changed: 24 additions & 2 deletions

@@ -95,7 +95,18 @@ class DataflowCreateJavaJobOperator(BaseOperator):
     :type job_name: str
     :param dataflow_default_options: Map of default job options.
     :type dataflow_default_options: dict
-    :param options: Map of job specific options.
+    :param options: Map of job-specific options. The key must be a string;
+        the value can be of different types:
+
+        * If the value is None, the single option - ``--key`` (without value) will be added.
+        * If the value is False, this option will be skipped.
+        * If the value is True, the single option - ``--key`` (without value) will be added.
+        * If the value is a list, the option will be repeated for each element.
+          If the value is ``['A', 'B']`` and the key is ``key``, then the options
+          ``--key=A --key=B`` will be added.
+        * Other value types will be replaced with their Python textual representation.
+
+        When defining labels (``labels`` option), you can also provide a dictionary.
     :type options: dict
     :param gcp_conn_id: The connection ID to use connecting to Google Cloud
         Platform.
@@ -402,7 +413,18 @@ class DataflowCreatePythonJobOperator(BaseOperator):
     :type py_options: list[str]
     :param dataflow_default_options: Map of default job options.
     :type dataflow_default_options: dict
-    :param options: Map of job specific options.
+    :param options: Map of job-specific options. The key must be a string;
+        the value can be of different types:
+
+        * If the value is None, the single option - ``--key`` (without value) will be added.
+        * If the value is False, this option will be skipped.
+        * If the value is True, the single option - ``--key`` (without value) will be added.
+        * If the value is a list, the option will be repeated for each element.
+          If the value is ``['A', 'B']`` and the key is ``key``, then the options
+          ``--key=A --key=B`` will be added.
+        * Other value types will be replaced with their Python textual representation.
+
+        When defining labels (``labels`` option), you can also provide a dictionary.
     :type options: dict
     :param py_interpreter: Python version of the beam pipeline.
         If None, this defaults to the python3.
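
The list-expansion behavior documented above is what lets a single option carry several extra packages. A minimal usage sketch, assuming the operator's standard task_id/py_file/options arguments; the bucket paths, file names, and IDs are illustrative placeholders, not values from this commit:

from airflow.providers.google.cloud.operators.dataflow import DataflowCreatePythonJobOperator

start_python_job = DataflowCreatePythonJobOperator(
    task_id="start_python_job",                         # placeholder task id
    py_file="gs://my-bucket/pipelines/my_pipeline.py",  # placeholder pipeline file
    job_name="example-job",                             # placeholder job name
    options={
        "staging_location": "gs://my-bucket/staging",   # placeholder bucket
        # a list value expands to repeated flags:
        # --extra-package=a.whl --extra-package=b.whl
        "extra-package": ["a.whl", "b.whl"],
    },
)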

tests/providers/google/cloud/hooks/test_dataflow.py

Lines changed: 71 additions & 10 deletions

@@ -19,6 +19,7 @@
 
 import copy
 import unittest
+from typing import Any, Dict
 
 import mock
 from mock import MagicMock
@@ -41,17 +42,17 @@
 JAR_FILE = 'unitest.jar'
 JOB_CLASS = 'com.example.UnitTest'
 PY_OPTIONS = ['-m']
-DATAFLOW_OPTIONS_PY = {
+DATAFLOW_VARIABLES_PY = {
     'project': 'test',
     'staging_location': 'gs://test/staging',
     'labels': {'foo': 'bar'}
 }
-DATAFLOW_OPTIONS_JAVA = {
+DATAFLOW_VARIABLES_JAVA = {
     'project': 'test',
     'stagingLocation': 'gs://test/staging',
     'labels': {'foo': 'bar'}
 }
-DATAFLOW_OPTIONS_TEMPLATE = {
+DATAFLOW_VARIABLES_TEMPLATE = {
     'project': 'test',
     'tempLocation': 'gs://test/temp',
     'zone': 'us-central1-f'
@@ -172,7 +173,7 @@ def test_start_python_dataflow(
         dataflowjob_instance = mock_dataflowjob.return_value
         dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_PY,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY,
             dataflow=PY_FILE, py_options=PY_OPTIONS,
         )
         expected_cmd = ["python3", '-m', PY_FILE,
@@ -184,6 +185,36 @@ def test_start_python_dataflow(
         self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]),
                              sorted(expected_cmd))
 
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    def test_start_python_dataflow_with_multiple_extra_packages(
+        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+    ):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        variables['extra-package'] = ['a.whl', 'b.whl']
+
+        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
+            job_name=JOB_NAME, variables=variables,
+            dataflow=PY_FILE, py_options=PY_OPTIONS,
+        )
+        expected_cmd = ["python3", '-m', PY_FILE,
+                        '--extra-package=a.whl',
+                        '--extra-package=b.whl',
+                        '--region=us-central1',
+                        '--runner=DataflowRunner', '--project=test',
+                        '--labels=foo=bar',
+                        '--staging_location=gs://test/staging',
+                        '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+
     @parameterized.expand([
         ('default_to_python3', 'python3'),
         ('major_version_2', 'python2'),
@@ -205,7 +236,7 @@ def test_start_python_dataflow_with_custom_interpreter(
         dataflowjob_instance = mock_dataflowjob.return_value
         dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_PY,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY,
             dataflow=PY_FILE, py_options=PY_OPTIONS,
             py_interpreter=py_interpreter,
         )
@@ -231,9 +262,39 @@ def test_start_java_dataflow(self, mock_conn,
         dataflowjob_instance = mock_dataflowjob.return_value
         dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_JAVA,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA,
+            jar=JAR_FILE)
+        expected_cmd = ['java', '-jar', JAR_FILE,
+                        '--region=us-central1',
+                        '--runner=DataflowRunner', '--project=test',
+                        '--stagingLocation=gs://test/staging',
+                        '--labels={"foo":"bar"}',
+                        '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]),
+                             sorted(expected_cmd))
+
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    def test_start_java_dataflow_with_multiple_values_in_variables(
+        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+    ):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+        variables['mock-option'] = ['a.whl', 'b.whl']
+
+        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
+            job_name=JOB_NAME, variables=variables,
             jar=JAR_FILE)
         expected_cmd = ['java', '-jar', JAR_FILE,
+                        '--mock-option=a.whl',
+                        '--mock-option=b.whl',
                         '--region=us-central1',
                         '--runner=DataflowRunner', '--project=test',
                         '--stagingLocation=gs://test/staging',
@@ -255,7 +316,7 @@ def test_start_java_dataflow_with_job_class(
         dataflowjob_instance = mock_dataflowjob.return_value
         dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_JAVA,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA,
             jar=JAR_FILE, job_class=JOB_CLASS)
         expected_cmd = ['java', '-cp', JAR_FILE, JOB_CLASS,
                         '--region=us-central1',
@@ -318,11 +379,11 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid):
         )
         launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
         self.dataflow_hook.start_template_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_TEMPLATE, parameters=PARAMETERS,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_TEMPLATE, parameters=PARAMETERS,
            dataflow_template=TEMPLATE,
         )
         options_with_region = {'region': 'us-central1'}
-        options_with_region.update(DATAFLOW_OPTIONS_TEMPLATE)
+        options_with_region.update(DATAFLOW_VARIABLES_TEMPLATE)
         options_with_region_without_project = copy.deepcopy(options_with_region)
         del options_with_region_without_project['project']
 
@@ -355,7 +416,7 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid):
     @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
     @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
     def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflowjob, mock_uuid):
-        dataflow_options_template = copy.deepcopy(DATAFLOW_OPTIONS_TEMPLATE)
+        dataflow_options_template = copy.deepcopy(DATAFLOW_VARIABLES_TEMPLATE)
         options_with_runtime_env = copy.deepcopy(RUNTIME_ENV)
         options_with_runtime_env.update(dataflow_options_template)
