@@ -172,15 +172,7 @@ def add_provide_context_to_python_operator(node: LN, capture: Capture, filename:
172172
173173 def remove_class (qry , class_name ) -> None :
174174 def _remover (node : LN , capture : Capture , filename : Filename ) -> None :
175- if node .type == 300 :
176- for ch in node .post_order ():
177- if isinstance (ch , Leaf ) and ch .value == class_name :
178- if ch .next_sibling and ch .next_sibling .value == "," :
179- ch .next_sibling .remove ()
180- ch .remove ()
181- elif node .type == 311 :
182- node .parent .remove ()
183- else :
175+ if node .type not in (300 , 311 ): # remove only definition
184176 node .remove ()
185177
186178 qry .select_class (class_name ).modify (_remover )
@@ -189,6 +181,10 @@ def _remover(node: LN, capture: Capture, filename: Filename) -> None:
189181 ("airflow.operators.bash" , "airflow.operators.bash_operator" ),
190182 ("airflow.operators.python" , "airflow.operators.python_operator" ),
191183 ("airflow.utils.session" , "airflow.utils.db" ),
184+ (
185+ "airflow.providers.cncf.kubernetes.operators.kubernetes_pod" ,
186+ "airflow.contrib.operators.kubernetes_pod_operator"
187+ ),
192188 ]
193189
194190 qry = Query ()
@@ -222,10 +218,20 @@ def _remover(node: LN, capture: Capture, filename: Filename) -> None:
222218 # Remove tags
223219 qry .select_method ("DAG" ).is_call ().modify (remove_tags_modifier )
224220
225- # Fix KubernetesPodOperator imports to use old path
226- qry .select_module (
227- "airflow.providers.cncf.kubernetes.operators.kubernetes_pod" ).rename (
228- "airflow.contrib.operators.kubernetes_pod_operator"
221+ # Fix AWS import in Google Cloud Transfer Service
222+ (
223+ qry
224+ .select_module ("airflow.providers.amazon.aws.hooks.base_aws" )
225+ .is_filename (include = r"cloud_storage_transfer_service\.py" )
226+ .rename ("airflow.contrib.hooks.aws_hook" )
227+ )
228+
229+ (
230+ qry
231+ .select_class ("AwsBaseHook" )
232+ .is_filename (include = r"cloud_storage_transfer_service\.py" )
233+ .filter (lambda n , c , f : n .type == 300 )
234+ .rename ("AwsHook" )
229235 )
230236
231237 # Fix BaseOperatorLinks imports
@@ -243,10 +249,28 @@ def _remover(node: LN, capture: Capture, filename: Filename) -> None:
243249 .modify (add_provide_context_to_python_operator )
244250 )
245251
252+ # Remove new class and rename usages of old
246253 remove_class (qry , "GKEStartPodOperator" )
254+ (
255+ qry
256+ .select_class ("GKEStartPodOperator" )
257+ .is_filename (include = r"example_kubernetes_engine\.py" )
258+ .rename ("GKEPodOperator" )
259+ )
247260
248261 qry .execute (write = True , silent = False , interactive = False )
249262
263+ # Add old import to GKE
264+ gke_path = os .path .join (
265+ dirname (__file__ ), "airflow" , "providers" , "google" , "cloud" , "operators" , "kubernetes_engine.py"
266+ )
267+ with open (gke_path , "a" ) as f :
268+ f .writelines (["" , "from airflow.contrib.operators.gcp_container_operator import GKEPodOperator" ])
269+
270+ gke_path = os .path .join (
271+ dirname (__file__ ), "airflow" , "providers" , "google" , "cloud" , "operators" , "kubernetes_engine.py"
272+ )
273+
250274
251275def get_source_providers_folder ():
252276 return os .path .join (dirname (__file__ ), os .pardir , "airflow" , "providers" )
0 commit comments