
Commit 9eacf60

Fix GCSToBigQueryOperator not respecting schema_obj (#28444)
* Fix GCSToBigQueryOperator not respecting schema_obj
1 parent 032a542 commit 9eacf60
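
For context, a minimal sketch of the usage this commit fixes is below: a DAG task that keeps its CSV data in one bucket and its JSON schema in another, so the operator has to honor both schema_object and schema_object_bucket. The DAG id, bucket, and object names here are illustrative only, not taken from the commit. As the diff below shows, before this change the downloaded schema was read from the data bucket and bound to a local variable, so the rest of execute() never used it.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator

with DAG("example_gcs_to_bq_schema_object", start_date=datetime(2022, 1, 1), schedule=None) as dag:
    load_csv = GCSToBigQueryOperator(
        task_id="load_csv_with_schema_object",
        bucket="my-data-bucket",  # bucket holding the CSV source objects
        source_objects=["data/input.csv"],
        schema_object_bucket="my-schema-bucket",  # separate bucket holding the JSON schema
        schema_object="schemas/input_schema.json",
        destination_project_dataset_table="my_project.my_dataset.my_table",
        write_disposition="WRITE_TRUNCATE",
    )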

File tree

2 files changed: +120 −4 lines


airflow/providers/google/cloud/transfers/gcs_to_bigquery.py

Lines changed: 6 additions & 4 deletions
@@ -313,15 +313,17 @@ def execute(self, context: Context):
             self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
         )
         source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects]
-        if not self.schema_fields:
+
+        if not self.schema_fields and self.schema_object and self.source_format != "DATASTORE_BACKUP":
             gcs_hook = GCSHook(
                 gcp_conn_id=self.gcp_conn_id,
                 delegate_to=self.delegate_to,
                 impersonation_chain=self.impersonation_chain,
             )
-            if self.schema_object and self.source_format != "DATASTORE_BACKUP":
-                schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8"))
-                self.log.info("Autodetected fields from schema object: %s", schema_fields)
+            self.schema_fields = json.loads(
+                gcs_hook.download(self.schema_object_bucket, self.schema_object).decode("utf-8")
+            )
+            self.log.info("Autodetected fields from schema object: %s", self.schema_fields)
 
         if self.external_table:
             self.log.info("Creating a new BigQuery table for storing data...")

tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py

Lines changed: 114 additions & 0 deletions
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 import unittest
 from unittest import mock
 from unittest.mock import MagicMock, call
@@ -51,6 +52,8 @@
     {"name": "id", "type": "INTEGER", "mode": "NULLABLE"},
     {"name": "name", "type": "STRING", "mode": "NULLABLE"},
 ]
+SCHEMA_BUCKET = "test-schema-bucket"
+SCHEMA_OBJECT = "test/schema/schema.json"
 TEST_SOURCE_OBJECTS = ["test/objects/test.csv"]
 TEST_SOURCE_OBJECTS_AS_STRING = "test/objects/test.csv"
 LABELS = {"k1": "v1"}
@@ -675,6 +678,117 @@ def test_source_objs_as_string_without_external_table_should_execute_successfull
 
         hook.return_value.insert_job.assert_has_calls(calls)
 
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_schema_obj_external_table_should_execute_successfully(self, bq_hook, gcs_hook):
+        bq_hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        gcs_hook.return_value.download.return_value = bytes(json.dumps(SCHEMA_FIELDS), "utf-8")
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_object_bucket=SCHEMA_BUCKET,
+            schema_object=SCHEMA_OBJECT,
+            write_disposition=WRITE_DISPOSITION,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            external_table=True,
+        )
+
+        operator.execute(context=MagicMock())
+
+        bq_hook.return_value.create_empty_table.assert_called_once_with(
+            table_resource={
+                "tableReference": {"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                "labels": None,
+                "description": None,
+                "externalDataConfiguration": {
+                    "source_uris": [f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                    "source_format": "CSV",
+                    "maxBadRecords": 0,
+                    "autodetect": True,
+                    "compression": "NONE",
+                    "csvOptions": {
+                        "fieldDelimeter": ",",
+                        "skipLeadingRows": None,
+                        "quote": None,
+                        "allowQuotedNewlines": False,
+                        "allowJaggedRows": False,
+                    },
+                },
+                "location": None,
+                "encryptionConfiguration": None,
+                "schema": {"fields": SCHEMA_FIELDS},
+            }
+        )
+        gcs_hook.return_value.download.assert_called_once_with(SCHEMA_BUCKET, SCHEMA_OBJECT)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
+    def test_schema_obj_without_external_table_should_execute_successfully(self, bq_hook, gcs_hook):
+        bq_hook.return_value.insert_job.side_effect = [
+            MagicMock(job_id=pytest.real_job_id, error_result=False),
+            pytest.real_job_id,
+        ]
+        bq_hook.return_value.generate_job_id.return_value = pytest.real_job_id
+        bq_hook.return_value.split_tablename.return_value = (PROJECT_ID, DATASET, TABLE)
+        gcs_hook.return_value.download.return_value = bytes(json.dumps(SCHEMA_FIELDS), "utf-8")
+
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            schema_object_bucket=SCHEMA_BUCKET,
+            schema_object=SCHEMA_OBJECT,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            write_disposition=WRITE_DISPOSITION,
+            external_table=False,
+        )
+
+        operator.execute(context=MagicMock())
+
+        calls = [
+            call(
+                configuration={
+                    "load": dict(
+                        autodetect=True,
+                        createDisposition="CREATE_IF_NEEDED",
+                        destinationTable={"projectId": PROJECT_ID, "datasetId": DATASET, "tableId": TABLE},
+                        destinationTableProperties={
+                            "description": None,
+                            "labels": None,
+                        },
+                        sourceFormat="CSV",
+                        skipLeadingRows=None,
+                        sourceUris=[f"gs://{TEST_BUCKET}/{TEST_SOURCE_OBJECTS_AS_STRING}"],
+                        writeDisposition=WRITE_DISPOSITION,
+                        ignoreUnknownValues=False,
+                        allowQuotedNewlines=False,
+                        encoding="UTF-8",
+                        schema={"fields": SCHEMA_FIELDS},
+                        allowJaggedRows=False,
+                        fieldDelimiter=",",
+                        maxBadRecords=0,
+                        quote=None,
+                        schemaUpdateOptions=(),
+                    ),
+                },
+                project_id=bq_hook.return_value.project_id,
+                location=None,
+                job_id=pytest.real_job_id,
+                timeout=None,
+                retry=DEFAULT_RETRY,
+                nowait=True,
+            ),
+        ]
+
+        bq_hook.return_value.insert_job.assert_has_calls(calls)
+        gcs_hook.return_value.download.assert_called_once_with(SCHEMA_BUCKET, SCHEMA_OBJECT)
+
     @mock.patch("airflow.providers.google.cloud.transfers.gcs_to_bigquery.BigQueryHook")
     def test_all_fields_should_be_present(self, hook):
         hook.return_value.insert_job.side_effect = [
