|
57 | 57 |
|
58 | 58 | from airflow.exceptions import AirflowException |
59 | 59 | from airflow.providers.common.sql.hooks.sql import DbApiHook |
| 60 | +from airflow.providers.google.cloud.utils.bigquery import bq_cast |
60 | 61 | from airflow.providers.google.common.consts import CLIENT_INFO |
61 | 62 | from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field |
62 | 63 | from airflow.utils.helpers import convert_camel_to_snake |
@@ -2740,7 +2741,7 @@ def next(self) -> list | None: |
2740 | 2741 | rows = query_results["rows"] |
2741 | 2742 |
|
2742 | 2743 | for dict_row in rows: |
2743 | | - typed_row = [_bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] |
| 2744 | + typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] |
2744 | 2745 | self.buffer.append(typed_row) |
2745 | 2746 |
|
2746 | 2747 | if not self.page_token: |
@@ -2845,25 +2846,6 @@ def _escape(s: str) -> str: |
2845 | 2846 | return e |
2846 | 2847 |
|
2847 | 2848 |
|
2848 | | -def _bq_cast(string_field: str, bq_type: str) -> None | int | float | bool | str: |
2849 | | - """ |
2850 | | - Helper method that casts a BigQuery row to the appropriate data types. |
2851 | | - This is useful because BigQuery returns all fields as strings. |
2852 | | - """ |
2853 | | - if string_field is None: |
2854 | | - return None |
2855 | | - elif bq_type == "INTEGER": |
2856 | | - return int(string_field) |
2857 | | - elif bq_type in ("FLOAT", "TIMESTAMP"): |
2858 | | - return float(string_field) |
2859 | | - elif bq_type == "BOOLEAN": |
2860 | | - if string_field not in ["true", "false"]: |
2861 | | - raise ValueError(f"{string_field} must have value 'true' or 'false'") |
2862 | | - return string_field == "true" |
2863 | | - else: |
2864 | | - return string_field |
2865 | | - |
2866 | | - |
2867 | 2849 | def split_tablename( |
2868 | 2850 | table_input: str, default_project_id: str, var_name: str | None = None |
2869 | 2851 | ) -> tuple[str, str, str]: |
@@ -3070,7 +3052,7 @@ def get_records(self, query_results: dict[str, Any]) -> list[Any]: |
3070 | 3052 | fields = query_results["schema"]["fields"] |
3071 | 3053 | col_types = [field["type"] for field in fields] |
3072 | 3054 | for dict_row in rows: |
3073 | | - typed_row = [_bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] |
| 3055 | + typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])] |
3074 | 3056 | buffer.append(typed_row) |
3075 | 3057 | return buffer |
3076 | 3058 |
|
|
0 commit comments