@@ -19,6 +19,7 @@
 
 import copy
 import unittest
+from typing import Any, Dict
 
 import mock
 from mock import MagicMock
@@ -41,17 +42,17 @@
 JAR_FILE = 'unitest.jar'
 JOB_CLASS = 'com.example.UnitTest'
 PY_OPTIONS = ['-m']
-DATAFLOW_OPTIONS_PY = {
+DATAFLOW_VARIABLES_PY = {
     'project': 'test',
     'staging_location': 'gs://test/staging',
     'labels': {'foo': 'bar'}
 }
-DATAFLOW_OPTIONS_JAVA = {
+DATAFLOW_VARIABLES_JAVA = {
     'project': 'test',
     'stagingLocation': 'gs://test/staging',
     'labels': {'foo': 'bar'}
 }
-DATAFLOW_OPTIONS_TEMPLATE = {
+DATAFLOW_VARIABLES_TEMPLATE = {
     'project': 'test',
     'tempLocation': 'gs://test/temp',
     'zone': 'us-central1-f'
@@ -172,7 +173,7 @@ def test_start_python_dataflow(
         dataflowjob_instance = mock_dataflowjob.return_value
         dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_PY,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY,
             dataflow=PY_FILE, py_options=PY_OPTIONS,
         )
         expected_cmd = ["python3", '-m', PY_FILE,
@@ -184,6 +185,36 @@ def test_start_python_dataflow(
         self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]),
                              sorted(expected_cmd))
 
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    def test_start_python_dataflow_with_multiple_extra_packages(
+        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+    ):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_PY)
+        variables['extra-package'] = ['a.whl', 'b.whl']
+
+        self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
+            job_name=JOB_NAME, variables=variables,
+            dataflow=PY_FILE, py_options=PY_OPTIONS,
+        )
+        expected_cmd = ["python3", '-m', PY_FILE,
+                        '--extra-package=a.whl',
+                        '--extra-package=b.whl',
+                        '--region=us-central1',
+                        '--runner=DataflowRunner', '--project=test',
+                        '--labels=foo=bar',
+                        '--staging_location=gs://test/staging',
+                        '--job_name={}-{}'.format(JOB_NAME, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]), sorted(expected_cmd))
+
     @parameterized.expand([
         ('default_to_python3', 'python3'),
         ('major_version_2', 'python2'),
@@ -205,7 +236,7 @@ def test_start_python_dataflow_with_custom_interpreter(
         dataflowjob_instance = mock_dataflowjob.return_value
         dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_python_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_PY,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_PY,
             dataflow=PY_FILE, py_options=PY_OPTIONS,
             py_interpreter=py_interpreter,
         )
@@ -231,9 +262,39 @@ def test_start_java_dataflow(self, mock_conn,
         dataflowjob_instance = mock_dataflowjob.return_value
         dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_JAVA,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA,
+            jar=JAR_FILE)
+        expected_cmd = ['java', '-jar', JAR_FILE,
+                        '--region=us-central1',
+                        '--runner=DataflowRunner', '--project=test',
+                        '--stagingLocation=gs://test/staging',
+                        '--labels={"foo":"bar"}',
+                        '--jobName={}-{}'.format(JOB_NAME, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[1]["cmd"]),
+                             sorted(expected_cmd))
+
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid4'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowRunner'))
+    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+    def test_start_java_dataflow_with_multiple_values_in_variables(
+        self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid
+    ):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        variables: Dict[str, Any] = copy.deepcopy(DATAFLOW_VARIABLES_JAVA)
+        variables['mock-option'] = ['a.whl', 'b.whl']
+
+        self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
+            job_name=JOB_NAME, variables=variables,
             jar=JAR_FILE)
         expected_cmd = ['java', '-jar', JAR_FILE,
+                        '--mock-option=a.whl',
+                        '--mock-option=b.whl',
                         '--region=us-central1',
                         '--runner=DataflowRunner', '--project=test',
                         '--stagingLocation=gs://test/staging',
@@ -255,7 +316,7 @@ def test_start_java_dataflow_with_job_class(
         dataflowjob_instance = mock_dataflowjob.return_value
         dataflowjob_instance.wait_for_done.return_value = None
         self.dataflow_hook.start_java_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_JAVA,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_JAVA,
             jar=JAR_FILE, job_class=JOB_CLASS)
         expected_cmd = ['java', '-cp', JAR_FILE, JOB_CLASS,
                         '--region=us-central1',
@@ -318,11 +379,11 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid):
         )
         launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
         self.dataflow_hook.start_template_dataflow(  # pylint: disable=no-value-for-parameter
-            job_name=JOB_NAME, variables=DATAFLOW_OPTIONS_TEMPLATE, parameters=PARAMETERS,
+            job_name=JOB_NAME, variables=DATAFLOW_VARIABLES_TEMPLATE, parameters=PARAMETERS,
             dataflow_template=TEMPLATE,
         )
         options_with_region = {'region': 'us-central1'}
-        options_with_region.update(DATAFLOW_OPTIONS_TEMPLATE)
+        options_with_region.update(DATAFLOW_VARIABLES_TEMPLATE)
         options_with_region_without_project = copy.deepcopy(options_with_region)
         del options_with_region_without_project['project']
 
@@ -355,7 +416,7 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid):
     @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
     @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
     def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflowjob, mock_uuid):
-        dataflow_options_template = copy.deepcopy(DATAFLOW_OPTIONS_TEMPLATE)
+        dataflow_options_template = copy.deepcopy(DATAFLOW_VARIABLES_TEMPLATE)
         options_with_runtime_env = copy.deepcopy(RUNTIME_ENV)
         options_with_runtime_env.update(dataflow_options_template)
 
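
The new tests pin down how list-valued entries in `variables` should be rendered: each element becomes its own repeated `--name=value` flag (e.g. `'extra-package': ['a.whl', 'b.whl']` expands to `--extra-package=a.whl --extra-package=b.whl`). Below is a minimal sketch of that expansion, assuming a standalone helper named `expand_variables_to_args` (illustrative only, not the hook's actual internal API; dict-valued options such as `labels` are omitted here).

from typing import Any, Dict, List


def expand_variables_to_args(variables: Dict[str, Any]) -> List[str]:
    """Render a variables mapping as Dataflow command-line arguments.

    A list value becomes one repeated --name=value flag per element,
    which is the behaviour the new 'extra-package' / 'mock-option'
    tests assert; any other value maps to a single flag.
    """
    args: List[str] = []
    for name, value in variables.items():
        if isinstance(value, list):
            args.extend('--{}={}'.format(name, item) for item in value)
        else:
            args.append('--{}={}'.format(name, value))
    return args


print(expand_variables_to_args({'extra-package': ['a.whl', 'b.whl'], 'project': 'test'}))
# ['--extra-package=a.whl', '--extra-package=b.whl', '--project=test']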