Skip to content

Commit 58287f0

Browse files
committed
Add support for polars DataFrame and LazyFrame
1 parent ed0ac9f commit 58287f0

1 file changed

Lines changed: 32 additions & 22 deletions

File tree

snakemake/utils.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -111,32 +111,42 @@ def set_defaults(validator, properties, instance, schema):
111111
if not isinstance(data, dict):
112112
try:
113113
import pandas as pd
114+
import pandas as pl
114115

115-
recordlist = []
116+
records = []
116117
if isinstance(data, pd.DataFrame):
117-
for i, record in enumerate(data.to_dict("records")):
118-
record = {k: v for k, v in record.items() if not pd.isnull(v)}
119-
try:
120-
if set_default:
121-
DefaultValidator(schema, resolver=resolver).validate(record)
122-
recordlist.append(record)
123-
else:
124-
jsonschema.validate(record, schema, resolver=resolver)
125-
except jsonschema.exceptions.ValidationError as e:
126-
raise WorkflowError(
127-
f"Error validating row {i} of data frame.", e
128-
)
129-
if set_default:
130-
newdata = pd.DataFrame(recordlist, data.index)
131-
newcol = ~newdata.columns.isin(data.columns)
132-
n = len(data.columns)
133-
for col in newdata.loc[:, newcol].columns:
134-
data.insert(n, col, newdata.loc[:, col])
135-
n = n + 1
136-
return
118+
records = data.to_dict("records")
119+
elif isinstance(data, pl.DataFrame):
120+
records = data.iter_rows(named=True)
121+
elif isinstance(data, pl.LazyFrame):
122+
# A LazyFrame is probably a large dataframe, so only the first 1000 records are checked
123+
records = data.head(1000).collect().iter_rows(named=True)
124+
else:
125+
raise WorkflowError("Unsupported data type for validation.")
126+
127+
recordlist = []
128+
for i, record in enumerate(records):
129+
# Exclude NULL values
130+
record = {k: v for k, v in record.items() if not pd.isnull(v)}
131+
try:
132+
if set_default:
133+
DefaultValidator(schema, resolver=resolver).validate(record)
134+
recordlist.append(record)
135+
else:
136+
jsonschema.validate(record, schema, resolver=resolver)
137+
except jsonschema.exceptions.ValidationError as e:
138+
raise WorkflowError(f"Error validating row {i} of data frame.", e)
139+
if set_default:
140+
newdata = pd.DataFrame(recordlist, data.index)
141+
newcol = ~newdata.columns.isin(data.columns)
142+
n = len(data.columns)
143+
for col in newdata.loc[:, newcol].columns:
144+
data.insert(n, col, newdata.loc[:, col])
145+
n = n + 1
146+
return
137147
except ImportError:
138148
pass
139-
raise WorkflowError("Unsupported data type for validation.")
149+
raise WorkflowError("Error validating data frame.")
140150
else:
141151
try:
142152
if set_default:

0 commit comments

Comments
 (0)