1818from __future__ import annotations
1919
2020import json
21+ import os
22+ import platform
23+ import tempfile
2124from unittest import mock
2225from unittest .mock import PropertyMock
2326
2730
2831from airflow .exceptions import AirflowException
2932from airflow .models import Connection
30- from airflow .providers .google .cloud .hooks .cloud_sql import CloudSQLDatabaseHook , CloudSQLHook
33+ from airflow .providers .google .cloud .hooks .cloud_sql import (
34+ CloudSQLDatabaseHook ,
35+ CloudSQLHook ,
36+ CloudSqlProxyRunner ,
37+ )
3138from tests .providers .google .cloud .utils .base_gcp_mock import (
3239 mock_base_gcp_hook_default_project_id ,
3340 mock_base_gcp_hook_no_default_project_id ,
@@ -847,8 +854,12 @@ def test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
847854 err = ctx .value
848855 assert "must be a readable file" in str (err )
849856
857+ @mock .patch ("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir" )
850858 @mock .patch ("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection" )
851- def test_cloudsql_database_hook_validate_socket_path_length_too_long (self , get_connection ):
859+ def test_cloudsql_database_hook_validate_socket_path_length_too_long (
860+ self , get_connection , gettempdir_mock
861+ ):
862+ gettempdir_mock .return_value = "/tmp"
852863 connection = Connection ()
853864 connection .set_extra (
854865 json .dumps (
@@ -870,8 +881,12 @@ def test_cloudsql_database_hook_validate_socket_path_length_too_long(self, get_c
870881 err = ctx .value
871882 assert "The UNIX socket path length cannot exceed" in str (err )
872883
884+ @mock .patch ("airflow.providers.google.cloud.hooks.cloud_sql.gettempdir" )
873885 @mock .patch ("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection" )
874- def test_cloudsql_database_hook_validate_socket_path_length_not_too_long (self , get_connection ):
886+ def test_cloudsql_database_hook_validate_socket_path_length_not_too_long (
887+ self , get_connection , gettempdir_mock
888+ ):
889+ gettempdir_mock .return_value = "/tmp"
875890 connection = Connection ()
876891 connection .set_extra (
877892 json .dumps (
@@ -1093,7 +1108,7 @@ def test_hook_with_correct_parameters_postgres_proxy_socket(self, get_connection
10931108 hook = CloudSQLDatabaseHook ()
10941109 connection = hook .create_connection ()
10951110 assert "postgres" == connection .conn_type
1096- assert "/tmp" in connection .host
1111+ assert tempfile . gettempdir () in connection .host
10971112 assert "example-project:europe-west1:testdb" in connection .host
10981113 assert connection .port is None
10991114 assert "testdb" == connection .schema
@@ -1166,7 +1181,7 @@ def test_hook_with_correct_parameters_mysql_proxy_socket(self, get_connection):
11661181 connection = hook .create_connection ()
11671182 assert "mysql" == connection .conn_type
11681183 assert "localhost" == connection .host
1169- assert "/tmp" in connection .extra_dejson ["unix_socket" ]
1184+ assert tempfile . gettempdir () in connection .extra_dejson ["unix_socket" ]
11701185 assert "example-project:europe-west1:testdb" in connection .extra_dejson ["unix_socket" ]
11711186 assert connection .port is None
11721187 assert "testdb" == connection .schema
@@ -1185,3 +1200,53 @@ def test_hook_with_correct_parameters_mysql_tcp(self, get_connection):
11851200 assert "127.0.0.1" == connection .host
11861201 assert 3200 != connection .port
11871202 assert "testdb" == connection .schema
1203+
1204+
1205+ def get_processor ():
1206+ processor = os .uname ().machine
1207+ if processor == "x86_64" :
1208+ processor = "amd64"
1209+ return processor
1210+
1211+
1212+ class TestCloudSqlProxyRunner :
1213+ @pytest .mark .parametrize (
1214+ ["version" , "download_url" ],
1215+ [
1216+ (
1217+ "v1.23.0" ,
1218+ "https://storage.googleapis.com/cloudsql-proxy/v1.23.0/cloud_sql_proxy."
1219+ f"{ platform .system ().lower ()} .{ get_processor ()} " ,
1220+ ),
1221+ (
1222+ "v1.23.0-preview.1" ,
1223+ "https://storage.googleapis.com/cloudsql-proxy/v1.23.0-preview.1/cloud_sql_proxy."
1224+ f"{ platform .system ().lower ()} .{ get_processor ()} " ,
1225+ ),
1226+ ],
1227+ )
1228+ def test_cloud_sql_proxy_runner_version_ok (self , version , download_url ):
1229+ runner = CloudSqlProxyRunner (
1230+ path_prefix = "12345678" ,
1231+ instance_specification = "project:us-east-1:instance" ,
1232+ sql_proxy_version = version ,
1233+ )
1234+ assert runner ._get_sql_proxy_download_url () == download_url
1235+
1236+ @pytest .mark .parametrize (
1237+ "version" ,
1238+ [
1239+ "v1.23." ,
1240+ "v1.23.0.." ,
1241+ "v1.23.0\\ " ,
1242+ "\\ " ,
1243+ ],
1244+ )
1245+ def test_cloud_sql_proxy_runner_version_nok (self , version ):
1246+ runner = CloudSqlProxyRunner (
1247+ path_prefix = "12345678" ,
1248+ instance_specification = "project:us-east-1:instance" ,
1249+ sql_proxy_version = version ,
1250+ )
1251+ with pytest .raises (ValueError , match = "The sql_proxy_version should match the regular expression" ):
1252+ runner ._get_sql_proxy_download_url ()
0 commit comments