Skip to content

Commit 898b839

Browse files
committed
using dtypekind enum values
1 parent ee8e0c4 commit 898b839

2 files changed

Lines changed: 49 additions & 1 deletion

File tree

sktime/datatypes/_dtypekind.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from enum import IntEnum
2+
3+
4+
class DtypeKind(IntEnum):
    """Integer enum identifying the kind of data held in a column.

    Members are plain integers (``IntEnum``), so they compare equal to
    their numeric value and can be stored or serialized as ``int``.

    The values are intentionally non-contiguous: the numeric kinds occupy
    0-2, the remaining kinds occupy 20-23.
    NOTE(review): the numbering appears to mirror the ``DtypeKind`` enum of
    the Python dataframe interchange protocol - confirm before changing it.

    Attributes
    ----------
    INT : int
        Signed integer data type.
    UINT : int
        Unsigned integer data type.
    FLOAT : int
        Floating point data type.
    BOOL : int
        Boolean data type.
    STRING : int
        String data type (UTF-8 encoded).
    DATETIME : int
        Datetime data type.
    CATEGORICAL : int
        Categorical data type.
    """

    INT = 0
    UINT = 1
    FLOAT = 2
    BOOL = 20
    STRING = 21  # UTF-8
    DATETIME = 22
    CATEGORICAL = 23

sktime/datatypes/_series/_check.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
from sktime.datatypes._common import _req
4646
from sktime.datatypes._common import _ret as ret
47+
from sktime.datatypes._dtypekind import DtypeKind
4748
from sktime.utils.validation._dependencies import _check_soft_dependencies
4849
from sktime.utils.validation.series import is_in_valid_index_types
4950

@@ -73,8 +74,23 @@ def check_pddataframe_series(obj, return_metadata=False, var_name="obj"):
7374
metadata["n_features"] = len(obj.columns)
7475
if _req("feature_names", return_metadata):
7576
metadata["feature_names"] = obj.columns.to_list()
77+
7678
if _req("feature_kind", return_metadata):
77-
metadata["feature_kind"] = obj.dtypes.to_list()
79+
col_dtypes = obj.dtypes.to_list()
80+
for i, dtype in enumerate(col_dtypes):
81+
if dtype == float:
82+
col_dtypes[i] = DtypeKind.FLOAT
83+
elif dtype == int:
84+
col_dtypes[i] = DtypeKind.INT
85+
elif dtype == np.uint:
86+
col_dtypes[i] = DtypeKind.UINT
87+
elif dtype == object:
88+
col_dtypes[i] = DtypeKind.CATEGORICAL
89+
elif dtype == bool:
90+
col_dtypes[i] = DtypeKind.BOOL
91+
elif dtype == pd.DatetimeIndex:
92+
col_dtypes[i] = DtypeKind.DATETIME
93+
metadata["feature_kind"] = col_dtypes
7894

7995
# check that columns are unique
8096
if not obj.columns.is_unique:

0 commit comments

Comments
 (0)