 import gzip as gz
 import os
 import shutil
+import time
 import warnings
 from contextlib import contextmanager
 from datetime import datetime
+from functools import partial
 from io import BytesIO
 from os import path
 from tempfile import NamedTemporaryFile
 
 from google.api_core.exceptions import NotFound
 from google.cloud import storage
+from google.cloud.exceptions import GoogleCloudError
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.utils import timezone
 from airflow.version import version
 
 RT = TypeVar('RT')  # pylint: disable=invalid-name
@@ -266,6 +270,7 @@ def download(
         filename: Optional[str] = None,
         chunk_size: Optional[int] = None,
         timeout: Optional[int] = DEFAULT_TIMEOUT,
+        num_max_attempts: Optional[int] = 1,
     ) -> Union[str, bytes]:
         """
         Downloads a file from Google Cloud Storage.
@@ -285,20 +290,43 @@ def download(
         :type chunk_size: int
         :param timeout: Request timeout in seconds.
         :type timeout: int
+        :param num_max_attempts: Number of attempts to download the file.
+        :type num_max_attempts: int
         """
         # TODO: future improvement check file size before downloading,
         # to check for local space availability
 
-        client = self.get_conn()
-        bucket = client.bucket(bucket_name)
-        blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
-
-        if filename:
-            blob.download_to_filename(filename, timeout=timeout)
-            self.log.info('File downloaded to %s', filename)
-            return filename
-        else:
-            return blob.download_as_string()
+        num_file_attempts = 0
+
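+        # Transient GoogleCloudError failures are retried with exponential backoff;
+        # the client, bucket and blob handles are rebuilt on every attempt.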
+        while num_file_attempts < num_max_attempts:
+            try:
+                num_file_attempts += 1
+                client = self.get_conn()
+                bucket = client.bucket(bucket_name)
+                blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
+
+                if filename:
+                    blob.download_to_filename(filename, timeout=timeout)
+                    self.log.info('File downloaded to %s', filename)
+                    return filename
+                else:
+                    return blob.download_as_string()
+
+            except GoogleCloudError:
+                if num_file_attempts == num_max_attempts:
+                    self.log.error(
+                        'Download attempt of object: %s from %s has failed. Attempt: %s, max %s.',
+                        object_name,
+                        bucket_name,
+                        num_file_attempts,
+                        num_max_attempts,
+                    )
+                    raise
+
+                # Wait with exponential backoff scheme before retrying.
+                timeout_seconds = 1.0 * 2 ** (num_file_attempts - 1)
+                time.sleep(timeout_seconds)
+                continue
 
     @_fallback_object_url_to_object_name_and_bucket_name()
     @contextmanager
@@ -362,7 +390,7 @@ def provide_file_and_upload(
             tmp_file.flush()
             self.upload(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name)
 
-    def upload(
+    def upload(  # pylint: disable=too-many-arguments
         self,
         bucket_name: str,
         object_name: str,
@@ -373,6 +401,7 @@ def upload(
         encoding: str = 'utf-8',
         chunk_size: Optional[int] = None,
         timeout: Optional[int] = DEFAULT_TIMEOUT,
+        num_max_attempts: int = 1,
     ) -> None:
         """
         Uploads a local file or file data as string or bytes to Google Cloud Storage.
@@ -395,7 +424,38 @@ def upload(
         :type chunk_size: int
         :param timeout: Request timeout in seconds.
         :type timeout: int
+        :param num_max_attempts: Number of attempts to try to upload the file.
+        :type num_max_attempts: int
         """
+
+        def _call_with_retry(f: Callable[[], None]) -> None:
+            """Helper function to upload a file or string with a retry mechanism and exponential back-off.
+
+            :param f: Callable that should be retried.
+            :type f: Callable[[], None]
+            """
+            num_file_attempts = 0
+
+            while num_file_attempts < num_max_attempts:
+                try:
+                    num_file_attempts += 1
+                    f()
+
+                except GoogleCloudError:
+                    if num_file_attempts == num_max_attempts:
+                        self.log.error(
+                            'Upload attempt of object: %s from %s has failed. Attempt: %s, max %s.',
+                            object_name,
+                            bucket_name,
+                            num_file_attempts,
+                            num_max_attempts,
+                        )
+                        raise
+
+                    # Wait with exponential backoff scheme before retrying.
+                    timeout_seconds = 1.0 * 2 ** (num_file_attempts - 1)
+                    time.sleep(timeout_seconds)
+                    continue
+
         client = self.get_conn()
         bucket = client.bucket(bucket_name)
         blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size)
@@ -416,7 +476,10 @@ def upload(
                     shutil.copyfileobj(f_in, f_out)
                     filename = filename_gz
 
-            blob.upload_from_filename(filename=filename, content_type=mime_type, timeout=timeout)
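+            # functools.partial freezes the call's arguments so the retry helper can
+            # invoke the same upload again on each attempt.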
+            _call_with_retry(
+                partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
+            )
+
             if gzip:
                 os.remove(filename)
             self.log.info('File %s uploaded to %s in %s bucket', filename, object_name, bucket_name)
@@ -430,7 +493,9 @@ def upload(
                 with gz.GzipFile(fileobj=out, mode="w") as f:
                     f.write(data)
                 data = out.getvalue()
-            blob.upload_from_string(data, content_type=mime_type, timeout=timeout)
+
+            _call_with_retry(partial(blob.upload_from_string, data, content_type=mime_type, timeout=timeout))
+
             self.log.info('Data stream uploaded to %s in %s bucket', object_name, bucket_name)
         else:
             raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")
@@ -481,10 +546,9 @@ def is_updated_after(self, bucket_name: str, object_name: str, ts: datetime) ->
481546 """
482547 blob_update_time = self .get_blob_update_time (bucket_name , object_name )
483548 if blob_update_time is not None :
484- import dateutil .tz
485549
486550 if not ts .tzinfo :
487- ts = ts .replace (tzinfo = dateutil . tz . tzutc () )
551+ ts = ts .replace (tzinfo = timezone . utc )
488552 self .log .info ("Verify object date: %s > %s" , blob_update_time , ts )
489553 if blob_update_time > ts :
490554 return True
@@ -508,12 +572,11 @@ def is_updated_between(
508572 """
509573 blob_update_time = self .get_blob_update_time (bucket_name , object_name )
510574 if blob_update_time is not None :
511- import dateutil .tz
512575
513576 if not min_ts .tzinfo :
514- min_ts = min_ts .replace (tzinfo = dateutil . tz . tzutc () )
577+ min_ts = min_ts .replace (tzinfo = timezone . utc )
515578 if not max_ts .tzinfo :
516- max_ts = max_ts .replace (tzinfo = dateutil . tz . tzutc () )
579+ max_ts = max_ts .replace (tzinfo = timezone . utc )
517580 self .log .info ("Verify object date: %s is between %s and %s" , blob_update_time , min_ts , max_ts )
518581 if min_ts <= blob_update_time < max_ts :
519582 return True
@@ -533,10 +596,9 @@ def is_updated_before(self, bucket_name: str, object_name: str, ts: datetime) ->
533596 """
534597 blob_update_time = self .get_blob_update_time (bucket_name , object_name )
535598 if blob_update_time is not None :
536- import dateutil .tz
537599
538600 if not ts .tzinfo :
539- ts = ts .replace (tzinfo = dateutil . tz . tzutc () )
601+ ts = ts .replace (tzinfo = timezone . utc )
540602 self .log .info ("Verify object date: %s < %s" , blob_update_time , ts )
541603 if blob_update_time < ts :
542604 return True
@@ -558,8 +620,6 @@ def is_older_than(self, bucket_name: str, object_name: str, seconds: int) -> boo
         if blob_update_time is not None:
             from datetime import timedelta
 
-            from airflow.utils import timezone
-
             current_time = timezone.utcnow()
             given_time = current_time - timedelta(seconds=seconds)
             self.log.info("Verify object date: %s is older than %s", blob_update_time, given_time)
@@ -650,6 +710,69 @@ def list(self, bucket_name, versions=None, max_results=None, prefix=None, delimi
                 break
         return ids
 
+    def list_by_timespan(
+        self,
+        bucket_name: str,
+        timespan_start: datetime,
+        timespan_end: datetime,
+        versions: Optional[bool] = None,
+        max_results: Optional[int] = None,
+        prefix: Optional[str] = None,
+        delimiter: Optional[str] = None,
+    ) -> list:
+        """
+        List all objects from the bucket with the given string prefix in name that were
+        updated in the time between ``timespan_start`` and ``timespan_end``.
+
+        :param bucket_name: bucket name
+        :type bucket_name: str
+        :param timespan_start: will return objects that were updated at or after this datetime (UTC)
+        :type timespan_start: datetime
+        :param timespan_end: will return objects that were updated before this datetime (UTC)
+        :type timespan_end: datetime
+        :param versions: if true, list all versions of the objects
+        :type versions: bool
+        :param max_results: max count of items to return in a single page of responses
+        :type max_results: int
+        :param prefix: prefix string which filters objects whose names begin with
+            this prefix
+        :type prefix: str
+        :param delimiter: filters objects based on the delimiter (e.g. '.csv')
+        :type delimiter: str
+        :return: a list of object names matching the filtering criteria
+        """
+        client = self.get_conn()
+        bucket = client.bucket(bucket_name)
+
+        ids = []
+        page_token = None
+
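+        # Page through list_blobs results until the next page token is exhausted.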
+        while True:
+            blobs = bucket.list_blobs(
+                max_results=max_results,
+                page_token=page_token,
+                prefix=prefix,
+                delimiter=delimiter,
+                versions=versions,
+            )
+
+            blob_names = []
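+            # Keep only blobs updated inside the half-open window
+            # [timespan_start, timespan_end); blob.updated is stamped as UTC first.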
+            for blob in blobs:
+                if timespan_start <= blob.updated.replace(tzinfo=timezone.utc) < timespan_end:
+                    blob_names.append(blob.name)
+
+            prefixes = blobs.prefixes
+            if prefixes:
+                ids += list(prefixes)
+            else:
+                ids += blob_names
+
+            page_token = blobs.next_page_token
+            if page_token is None:
+                # empty next page token
+                break
+        return ids
+
775+
     def get_size(self, bucket_name: str, object_name: str) -> int:
         """
         Gets the size of a file in Google Cloud Storage.