Skip to content

Commit dc4640b

Browse files
authored
feat: consistent selector parameters (#983)
### Summary of Changes - Name all parameters `selector` that select a subset of the columns of a `Table` by name or (later) by a `ColumnSelector`. - Parameters of `Table.join` and `Table.to_tabular_dataset` are deliberately unchanged, since I want users to always specify column names here explicitly. - Remove the ability to pass a lambda predicate to `select_columns`. - This was inconsistent with the other `selector` parameters. - It was also quite slow, since we sequentially looped over the columns. The upcoming `ColumnSelector`s will cover common cases and be more performant. - For all other cases, simply use `Table.to_columns`, a list comprehension, and `Table.from_columns`,
1 parent 2db9069 commit dc4640b

33 files changed

Lines changed: 220 additions & 253 deletions

docs/tutorials/classification.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@
208208
"source": [
209209
"from safeds.data.tabular.transformation import SimpleImputer\n",
210210
"\n",
211-
"simple_imputer = SimpleImputer(column_names=[\"age\", \"fare\"], strategy=SimpleImputer.Strategy.mean())\n",
211+
"simple_imputer = SimpleImputer(selector=[\"age\", \"fare\"], strategy=SimpleImputer.Strategy.mean())\n",
212212
"fitted_simple_imputer_train, transformed_train_data = simple_imputer.fit_and_transform(train_table)\n",
213213
"transformed_test_data = fitted_simple_imputer_train.transform(test_table)"
214214
]
@@ -241,7 +241,7 @@
241241
"from safeds.data.tabular.transformation import OneHotEncoder\n",
242242
"\n",
243243
"fitted_one_hot_encoder_train, transformed_train_data = OneHotEncoder(\n",
244-
" column_names=[\"sex\", \"port_embarked\"],\n",
244+
" selector=[\"sex\", \"port_embarked\"],\n",
245245
").fit_and_transform(transformed_train_data)\n",
246246
"transformed_test_data = fitted_one_hot_encoder_train.transform(transformed_test_data)"
247247
]

docs/tutorials/data_processing.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@
510510
"source": [
511511
"from safeds.data.tabular.transformation import SimpleImputer\n",
512512
"\n",
513-
"imputer = SimpleImputer(SimpleImputer.Strategy.constant(0), column_names=[\"age\", \"fare\", \"cabin\", \"port_embarked\"]).fit(\n",
513+
"imputer = SimpleImputer(SimpleImputer.Strategy.constant(0), selector=[\"age\", \"fare\", \"cabin\", \"port_embarked\"]).fit(\n",
514514
" titanic,\n",
515515
")\n",
516516
"imputer.transform(titanic_slice)"
@@ -583,7 +583,7 @@
583583
"source": [
584584
"from safeds.data.tabular.transformation import LabelEncoder\n",
585585
"\n",
586-
"encoder = LabelEncoder(column_names=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
586+
"encoder = LabelEncoder(selector=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
587587
"encoder.transform(titanic_slice)"
588588
]
589589
},
@@ -674,7 +674,7 @@
674674
"source": [
675675
"from safeds.data.tabular.transformation import OneHotEncoder\n",
676676
"\n",
677-
"encoder = OneHotEncoder(column_names=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
677+
"encoder = OneHotEncoder(selector=[\"sex\", \"port_embarked\"]).fit(titanic)\n",
678678
"encoder.transform(titanic_slice)"
679679
]
680680
},
@@ -745,7 +745,7 @@
745745
"source": [
746746
"from safeds.data.tabular.transformation import RangeScaler\n",
747747
"\n",
748-
"scaler = RangeScaler(column_names=\"age\", min_=0.0, max_=1.0).fit(titanic)\n",
748+
"scaler = RangeScaler(selector=\"age\", min_=0.0, max_=1.0).fit(titanic)\n",
749749
"scaler.transform(titanic_slice)"
750750
]
751751
},
@@ -816,7 +816,7 @@
816816
"source": [
817817
"from safeds.data.tabular.transformation import StandardScaler\n",
818818
"\n",
819-
"scaler = StandardScaler(column_names=[\"age\", \"travel_class\"]).fit(titanic)\n",
819+
"scaler = StandardScaler(selector=[\"age\", \"travel_class\"]).fit(titanic)\n",
820820
"scaler.transform(titanic_slice)"
821821
]
822822
},

src/safeds/_validation/_check_bounds_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _check_bounds(
3636
if actual is None:
3737
return # Skip the check if the actual value is None (i.e., not provided).
3838

39-
if lower_bound is None:
39+
if lower_bound is None: # pragma: no cover
4040
lower_bound = _OpenBound(float("-inf"))
4141
if upper_bound is None:
4242
upper_bound = _OpenBound(float("inf"))
@@ -148,7 +148,7 @@ def _to_string_as_upper_bound(self) -> str:
148148

149149

150150
def _float_to_string(value: float) -> str:
151-
if value == float("-inf"):
151+
if value == float("-inf"): # pragma: no cover
152152
return "-\u221e"
153153
elif value == float("inf"):
154154
return "\u221e"

src/safeds/_validation/_check_column_is_numeric_module.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,19 @@ def _check_column_is_numeric(
4949

5050
def _check_columns_are_numeric(
5151
table_or_schema: Table | Schema,
52-
column_names: str | list[str],
52+
selector: str | list[str],
5353
*,
5454
operation: str = "do a numeric operation",
5555
) -> None:
5656
"""
57-
Check if the columns with the specified names are numeric and raise an error if they are not.
58-
59-
Missing columns are ignored. Use `_check_columns_exist` to check for missing columns.
57+
Check if the specified columns are numeric and raise an error if they are not. Missing columns are ignored.
6058
6159
Parameters
6260
----------
6361
table_or_schema:
6462
The table or schema to check.
65-
column_names:
66-
The column names to check.
63+
selector:
64+
The columns to check.
6765
operation:
6866
The operation that is performed on the columns. This is used in the error message.
6967
@@ -76,17 +74,17 @@ def _check_columns_are_numeric(
7674

7775
if isinstance(table_or_schema, Table):
7876
table_or_schema = table_or_schema.schema
79-
if isinstance(column_names, str):
80-
column_names = [column_names]
77+
if isinstance(selector, str): # pragma: no cover
78+
selector = [selector]
8179

82-
if len(column_names) > 1:
80+
if len(selector) > 1:
8381
# Create a set for faster containment checks
8482
known_names: Container = set(table_or_schema.column_names)
8583
else:
8684
known_names = table_or_schema.column_names
8785

8886
non_numeric_names = [
89-
name for name in column_names if name in known_names and not table_or_schema.get_column_type(name).is_numeric
87+
name for name in selector if name in known_names and not table_or_schema.get_column_type(name).is_numeric
9088
]
9189
if non_numeric_names:
9290
message = _build_error_message(non_numeric_names, operation)

src/safeds/_validation/_check_schema_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _check_schema(
7575

7676

7777
def _check_types(expected_schema: Schema, actual_schema: Schema, *, check_types: _TypeCheckingMode) -> None:
78-
if check_types == "off":
78+
if check_types == "off": # pragma: no cover
7979
return
8080

8181
mismatched_types: list[tuple[str, pl.DataType, pl.DataType]] = []

src/safeds/data/labeled/containers/_image_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def __init__(self, column: Column) -> None:
448448
)
449449
# TODO: should not one-hot-encode the target. label encoding without order is sufficient. should also not
450450
# be done automatically?
451-
self._one_hot_encoder = OneHotEncoder(column_names=self._column_name).fit(column_as_table)
451+
self._one_hot_encoder = OneHotEncoder(selector=self._column_name).fit(column_as_table)
452452
self._tensor = torch.Tensor(
453453
self._one_hot_encoder.transform(column_as_table)._data_frame.to_torch(dtype=pl.Float32),
454454
).to(_get_device())

src/safeds/data/tabular/containers/_table.py

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ def has_column(self, name: str) -> bool:
775775

776776
def remove_columns(
777777
self,
778-
names: str | list[str],
778+
selector: str | list[str],
779779
*,
780780
ignore_unknown_names: bool = False,
781781
) -> Table:
@@ -786,8 +786,8 @@ def remove_columns(
786786
787787
Parameters
788788
----------
789-
names:
790-
The names of the columns to remove.
789+
selector:
790+
The columns to remove.
791791
ignore_unknown_names:
792792
If set to True, columns that are not present in the table will be ignored.
793793
If set to False, an error will be raised if any of the specified columns do not exist.
@@ -831,18 +831,18 @@ def remove_columns(
831831
Related
832832
-------
833833
- [select_columns][safeds.data.tabular.containers._table.Table.select_columns]:
834-
Keep only a subset of the columns. This method accepts either column names, or a predicate.
834+
Keep only a subset of the columns.
835835
- [remove_columns_with_missing_values][safeds.data.tabular.containers._table.Table.remove_columns_with_missing_values]
836836
- [remove_non_numeric_columns][safeds.data.tabular.containers._table.Table.remove_non_numeric_columns]
837837
"""
838-
if isinstance(names, str):
839-
names = [names]
838+
if isinstance(selector, str):
839+
selector = [selector]
840840

841841
if not ignore_unknown_names:
842-
_check_columns_exist(self, names)
842+
_check_columns_exist(self, selector)
843843

844844
return Table._from_polars_lazy_frame(
845-
self._lazy_frame.drop(names, strict=not ignore_unknown_names),
845+
self._lazy_frame.drop(selector, strict=not ignore_unknown_names),
846846
)
847847

848848
def remove_columns_with_missing_values(
@@ -900,7 +900,7 @@ def remove_columns_with_missing_values(
900900
- [KNearestNeighborsImputer][safeds.data.tabular.transformation._k_nearest_neighbors_imputer.KNearestNeighborsImputer]:
901901
Replace missing values with a value computed from the nearest neighbors.
902902
- [select_columns][safeds.data.tabular.containers._table.Table.select_columns]:
903-
Keep only a subset of the columns. This method accepts either column names, or a predicate.
903+
Keep only a subset of the columns.
904904
- [remove_columns][safeds.data.tabular.containers._table.Table.remove_columns]:
905905
Remove columns from the table by name.
906906
- [remove_non_numeric_columns][safeds.data.tabular.containers._table.Table.remove_non_numeric_columns]
@@ -955,7 +955,7 @@ def remove_non_numeric_columns(self) -> Table:
955955
Related
956956
-------
957957
- [select_columns][safeds.data.tabular.containers._table.Table.select_columns]:
958-
Keep only a subset of the columns. This method accepts either column names, or a predicate.
958+
Keep only a subset of the columns.
959959
- [remove_columns][safeds.data.tabular.containers._table.Table.remove_columns]:
960960
Remove columns from the table by name.
961961
- [remove_columns_with_missing_values][safeds.data.tabular.containers._table.Table.remove_columns_with_missing_values]
@@ -1113,21 +1113,17 @@ def replace_column(
11131113

11141114
def select_columns(
11151115
self,
1116-
selector: str | list[str] | Callable[[Column], bool],
1116+
selector: str | list[str],
11171117
) -> Table:
11181118
"""
11191119
Select a subset of the columns and return the result as a new table.
11201120
1121-
**Notes:**
1122-
1123-
- The original table is not modified.
1124-
- If the `selector` is a custom function, this operation must fully load the data into memory, which can be
1125-
expensive.
1121+
**Note:** The original table is not modified.
11261122
11271123
Parameters
11281124
----------
11291125
selector:
1130-
The names of the columns to keep, or a predicate that decides whether to keep a column.
1126+
The columns to keep.
11311127
11321128
Returns
11331129
-------
@@ -1161,23 +1157,11 @@ def select_columns(
11611157
- [remove_columns_with_missing_values][safeds.data.tabular.containers._table.Table.remove_columns_with_missing_values]
11621158
- [remove_non_numeric_columns][safeds.data.tabular.containers._table.Table.remove_non_numeric_columns]
11631159
"""
1164-
import polars as pl
1165-
1166-
# Select by predicate
1167-
if callable(selector):
1168-
return Table._from_polars_lazy_frame(
1169-
pl.LazyFrame(
1170-
[column._series for column in self.to_columns() if selector(column)],
1171-
),
1172-
)
1173-
1174-
# Select by column names
1175-
else:
1176-
_check_columns_exist(self, selector)
1160+
_check_columns_exist(self, selector)
11771161

1178-
return Table._from_polars_lazy_frame(
1179-
self._lazy_frame.select(selector),
1180-
)
1162+
return Table._from_polars_lazy_frame(
1163+
self._lazy_frame.select(selector),
1164+
)
11811165

11821166
def transform_columns(
11831167
self,
@@ -1611,7 +1595,7 @@ def remove_rows_by_column(
16111595
def remove_rows_with_missing_values(
16121596
self,
16131597
*,
1614-
column_names: str | list[str] | None = None,
1598+
selector: str | list[str] | None = None,
16151599
) -> Table:
16161600
"""
16171601
Remove rows that contain missing values in the specified columns and return the result as a new table.
@@ -1624,8 +1608,8 @@ def remove_rows_with_missing_values(
16241608
16251609
Parameters
16261610
----------
1627-
column_names:
1628-
The names of the columns to check. If None, all columns are checked.
1611+
selector:
1612+
The columns to check. If None, all columns are checked.
16291613
16301614
Returns
16311615
-------
@@ -1645,7 +1629,7 @@ def remove_rows_with_missing_values(
16451629
| 1 | 4 |
16461630
+-----+-----+
16471631
1648-
>>> table.remove_rows_with_missing_values(column_names=["b"])
1632+
>>> table.remove_rows_with_missing_values(selector=["b"])
16491633
+------+-----+
16501634
| a | b |
16511635
| --- | --- |
@@ -1669,18 +1653,18 @@ def remove_rows_with_missing_values(
16691653
- [remove_duplicate_rows][safeds.data.tabular.containers._table.Table.remove_duplicate_rows]
16701654
- [remove_rows_with_outliers][safeds.data.tabular.containers._table.Table.remove_rows_with_outliers]
16711655
"""
1672-
if isinstance(column_names, list) and not column_names:
1656+
if isinstance(selector, list) and not selector:
16731657
# polars panics in this case
16741658
return self
16751659

16761660
return Table._from_polars_lazy_frame(
1677-
self._lazy_frame.drop_nulls(subset=column_names),
1661+
self._lazy_frame.drop_nulls(subset=selector),
16781662
)
16791663

16801664
def remove_rows_with_outliers(
16811665
self,
16821666
*,
1683-
column_names: str | list[str] | None = None,
1667+
selector: str | list[str] | None = None,
16841668
z_score_threshold: float = 3,
16851669
) -> Table:
16861670
"""
@@ -1701,8 +1685,8 @@ def remove_rows_with_outliers(
17011685
17021686
Parameters
17031687
----------
1704-
column_names:
1705-
Names of the columns to consider. If None, all numeric columns are considered.
1688+
selector:
1689+
The columns to check. If None, all columns are checked.
17061690
z_score_threshold:
17071691
The z-score threshold for detecting outliers. Must be greater than or equal to 0.
17081692
@@ -1755,14 +1739,14 @@ def remove_rows_with_outliers(
17551739
lower_bound=_ClosedBound(0),
17561740
)
17571741

1758-
if column_names is None:
1759-
column_names = self.column_names
1742+
if selector is None:
1743+
selector = self.column_names
17601744

17611745
import polars as pl
17621746
import polars.selectors as cs
17631747

17641748
# polar's `all_horizontal` raises a `ComputeError` if there are no columns
1765-
selected = self._lazy_frame.select(cs.numeric() & cs.by_name(column_names))
1749+
selected = self._lazy_frame.select(cs.numeric() & cs.by_name(selector))
17661750
if not selected.collect_schema().names():
17671751
return self
17681752

@@ -2268,9 +2252,9 @@ def join(
22682252
right_table:
22692253
The table to join with the left table.
22702254
left_names:
2271-
Name or list of names of columns to join on in the left table.
2255+
Names of columns to join on in the left table.
22722256
right_names:
2273-
Name or list of names of columns to join on in the right table.
2257+
Names of columns to join on in the right table.
22742258
mode:
22752259
Specify which type of join you want to use.
22762260

0 commit comments

Comments
 (0)