2424import warnings
2525from collections import OrderedDict
2626from tempfile import NamedTemporaryFile , TemporaryDirectory
27- from typing import Any , Dict , List , Optional , Union
27+ from typing import Any , Dict , Iterable , List , Mapping , Optional , Union
2828
2929import pandas
3030import unicodecsv as csv
@@ -857,15 +857,15 @@ def get_conn(self, schema: Optional[str] = None) -> Any:
857857
858858 def _get_results (
859859 self ,
860- hql : Union [str , List [str ]],
860+ sql : Union [str , List [str ]],
861861 schema : str = 'default' ,
862862 fetch_size : Optional [int ] = None ,
863- hive_conf : Optional [Dict [ Any , Any ]] = None ,
863+ hive_conf : Optional [Union [ Iterable , Mapping ]] = None ,
864864 ) -> Any :
865865 from pyhive .exc import ProgrammingError
866866
867- if isinstance (hql , str ):
868- hql = [hql ]
867+ if isinstance (sql , str ):
868+ sql = [sql ]
869869 previous_description = None
870870 with contextlib .closing (self .get_conn (schema )) as conn , contextlib .closing (conn .cursor ()) as cur :
871871
@@ -882,7 +882,7 @@ def _get_results(
882882 for k , v in env_context .items ():
883883 cur .execute (f"set { k } ={ v } " )
884884
885- for statement in hql :
885+ for statement in sql :
886886 cur .execute (statement )
887887 # we only get results of statements that return them
888888 lowered_statement = statement .lower ().strip ()
@@ -911,29 +911,29 @@ def _get_results(
911911
912912 def get_results (
913913 self ,
914- hql : str ,
914+ sql : Union [ str , List [ str ]] ,
915915 schema : str = 'default' ,
916916 fetch_size : Optional [int ] = None ,
917- hive_conf : Optional [Dict [ Any , Any ]] = None ,
917+ hive_conf : Optional [Union [ Iterable , Mapping ]] = None ,
918918 ) -> Dict [str , Any ]:
919919 """
920920 Get results of the provided hql in target schema.
921921
922- :param hql : hql to be executed.
922+ :param sql : hql to be executed.
923923 :param schema: target schema, defaults to 'default'.
924924 :param fetch_size: max size of result to fetch.
925925 :param hive_conf: hive_conf to execute along with the hql.
926926 :return: results of hql execution, dict with data (list of results) and header
927927 :rtype: dict
928928 """
929- results_iter = self ._get_results (hql , schema , fetch_size = fetch_size , hive_conf = hive_conf )
929+ results_iter = self ._get_results (sql , schema , fetch_size = fetch_size , hive_conf = hive_conf )
930930 header = next (results_iter )
931931 results = {'data' : list (results_iter ), 'header' : header }
932932 return results
933933
934934 def to_csv (
935935 self ,
936- hql : str ,
936+ sql : str ,
937937 csv_filepath : str ,
938938 schema : str = 'default' ,
939939 delimiter : str = ',' ,
@@ -945,7 +945,7 @@ def to_csv(
945945 """
946946 Execute hql in target schema and write results to a csv file.
947947
948- :param hql : hql to be executed.
948+ :param sql : hql to be executed.
949949 :param csv_filepath: filepath of csv to write results into.
950950 :param schema: target schema, defaults to 'default'.
951951 :param delimiter: delimiter of the csv file, defaults to ','.
@@ -955,7 +955,7 @@ def to_csv(
955955 :param hive_conf: hive_conf to execute along with the hql.
956956
957957 """
958- results_iter = self ._get_results (hql , schema , fetch_size = fetch_size , hive_conf = hive_conf )
958+ results_iter = self ._get_results (sql , schema , fetch_size = fetch_size , hive_conf = hive_conf )
959959 header = next (results_iter )
960960 message = None
961961
@@ -982,14 +982,14 @@ def to_csv(
982982 self .log .info ("Done. Loaded a total of %s rows." , i )
983983
984984 def get_records (
985- self , hql : str , schema : str = 'default' , hive_conf : Optional [Dict [ Any , Any ]] = None
985+ self , sql : Union [ str , List [ str ]], parameters : Optional [Union [ Iterable , Mapping ]] = None , ** kwargs
986986 ) -> Any :
987987 """
988- Get a set of records from a Hive query.
988+ Get a set of records from a Hive query. You can optionally pass a 'schema' kwarg
989+ which specifies the target schema and defaults to 'default'.
989990
990- :param hql: hql to be executed.
991- :param schema: target schema, default to 'default'.
992- :param hive_conf: hive_conf to execute alone with the hql.
991+ :param sql: hql to be executed.
992+ :param parameters: optional configuration passed to get_results
993993 :return: result of hive execution
994994 :rtype: list
995995
@@ -998,19 +998,20 @@ def get_records(
998998 >>> len(hh.get_records(sql))
999999 100
10001000 """
1001- return self .get_results (hql , schema = schema , hive_conf = hive_conf )['data' ]
1001+ schema = kwargs ['schema' ] if 'schema' in kwargs else 'default'
1002+ return self .get_results (sql , schema = schema , hive_conf = parameters )['data' ]
10021003
10031004 def get_pandas_df ( # type: ignore
10041005 self ,
1005- hql : str ,
1006+ sql : str ,
10061007 schema : str = 'default' ,
10071008 hive_conf : Optional [Dict [Any , Any ]] = None ,
10081009 ** kwargs ,
10091010 ) -> pandas .DataFrame :
10101011 """
10111012 Get a pandas dataframe from a Hive query
10121013
1013- :param hql : hql to be executed.
1014+ :param sql : hql to be executed.
10141015 :param schema: target schema, defaults to 'default'.
10151016 :param hive_conf: hive_conf to execute along with the hql.
10161017 :param kwargs: (optional) passed into pandas.DataFrame constructor
@@ -1025,6 +1026,6 @@ def get_pandas_df( # type: ignore
10251026
10261027 :return: pandas.DataFrame
10271028 """
1028- res = self .get_results (hql , schema = schema , hive_conf = hive_conf )
1029+ res = self .get_results (sql , schema = schema , hive_conf = hive_conf )
10291030 df = pandas .DataFrame (res ['data' ], columns = [c [0 ] for c in res ['header' ]], ** kwargs )
10301031 return df
0 commit comments