Skip to content

Commit f88ea80

Browse files
committed
fix: rework UWheelOptimizer to use OptimizerRule, closes #11
1 parent cfaabbb commit f88ea80

10 files changed

Lines changed: 368 additions & 560 deletions

File tree

Cargo.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ resolver = "2"
33
members = ["benchmarks/nyc_taxi_bench", "datafusion-uwheel"]
44

55
[workspace.package]
6-
version = "38.0.0"
6+
version = "40.0.0"
77
edition = "2021"
88
authors = ["Max Meldrum <max@meldrum.se>"]
99
license = "Apache-2.0"
@@ -19,9 +19,7 @@ uwheel = { version = "0.2.0", default-features = false, features = [
1919
"max",
2020
"all",
2121
] }
22-
datafusion = "38.0.0"
23-
async-trait = "0.1.81"
22+
datafusion = "40.0.0"
2423
chrono = "0.4.38"
2524
bitpacking = "0.9.2"
2625
tokio = "1.38.1"
27-
futures = "0.3.30"

benchmarks/nyc_taxi_bench/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ debug = []
88

99
[dependencies]
1010
datafusion-uwheel = { path = "../../datafusion-uwheel" }
11-
datafusion = "38.0.0"
11+
datafusion = "40.0.0"
1212
mimalloc = { version = "*", default-features = false, optional = true }
1313
tokio = { version = "1", features = ["full"] }
1414
chrono = "0.4.38"

benchmarks/nyc_taxi_bench/src/main.rs

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use datafusion::datasource::file_format::parquet::ParquetFormat;
55
use datafusion::datasource::listing::{
66
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
77
};
8-
use datafusion_uwheel::UWheelOptimizer;
8+
use datafusion_uwheel::{IndexBuilder, UWheelOptimizer};
99

1010
use chrono::{DateTime, NaiveDate, Utc};
1111
use clap::Parser;
@@ -45,7 +45,7 @@ async fn main() -> Result<()> {
4545
let filename = "../../data/yellow_tripdata_2022-01.parquet";
4646

4747
// register parquet file with the execution context
48-
ctx.register_parquet("yellow_tripdata", &filename, ParquetReadOptions::default())
48+
ctx.register_parquet("yellow_tripdata", filename, ParquetReadOptions::default())
4949
.await?;
5050

5151
// Create ctx with UWheelOptimizer
@@ -78,14 +78,24 @@ async fn main() -> Result<()> {
7878
Builder::new("tpep_dropoff_datetime")
7979
.with_name("yellow_tripdata")
8080
.with_min_max_wheels(vec!["fare_amount", "trip_distance"]) // Create Min/Max wheels for the columns "fare_amount" and "trip_distance"
81-
.with_sum_wheels(vec!["fare_amount"])
8281
.build_with_provider(provider)
8382
.await
8483
.unwrap(),
8584
);
8685

87-
// Set UWheelOptimizer as the query planner
88-
let session_state = uwheel_ctx.state().with_query_planner(optimizer.clone());
86+
// Build index on fare_amount using SUM as aggregate
87+
optimizer
88+
.build_index(IndexBuilder::with_col_and_aggregate(
89+
"fare_amount",
90+
datafusion_uwheel::AggregateType::Sum,
91+
))
92+
.await
93+
.unwrap();
94+
95+
// Set UWheelOptimizer as optimizer rule
96+
let session_state = uwheel_ctx
97+
.state()
98+
.with_optimizer_rules(vec![optimizer.clone()]);
8999
let uwheel_ctx = SessionContext::new_with_state(session_state);
90100

91101
// Register the table using the underlying provider
@@ -131,33 +141,33 @@ pub async fn bench(
131141
ranges: &[(u64, u64)],
132142
fares: &[f64],
133143
) {
134-
bench_datafusion_count("datafusion-count(*)", &ctx, &ranges).await;
135-
bench_datafusion_count("datafusion-uwheel-count(*)", &uwheel_ctx, &ranges).await;
144+
bench_datafusion_count("datafusion-count(*)", ctx, ranges).await;
145+
bench_datafusion_count("datafusion-uwheel-count(*)", uwheel_ctx, ranges).await;
136146

137-
bench_datafusion_sum_fare_amount("datafusion-sum(fare_amount)", &ctx, &ranges).await;
138-
bench_datafusion_sum_fare_amount("datafusion-uwheel-sum(fare_amount)", &uwheel_ctx, &ranges)
147+
bench_datafusion_sum_fare_amount("datafusion-sum(fare_amount)", ctx, ranges).await;
148+
bench_datafusion_sum_fare_amount("datafusion-uwheel-sum(fare_amount)", uwheel_ctx, ranges)
139149
.await;
140150

141151
bench_min_max_projection(
142152
"datafusion-select(*)-fare-amount-filter",
143-
&ctx,
144-
&ranges,
145-
&fares,
153+
ctx,
154+
ranges,
155+
fares,
146156
)
147157
.await;
148158
bench_min_max_projection(
149159
"datafusion-uwheel-select(*)-fare-amount-filter",
150-
&uwheel_ctx,
151-
&ranges,
152-
&fares,
160+
uwheel_ctx,
161+
ranges,
162+
fares,
153163
)
154164
.await;
155165

156-
bench_datafusion_temporal_projection("datafusion-select(*)-count-filter", &ctx, &ranges).await;
166+
bench_datafusion_temporal_projection("datafusion-select(*)-count-filter", ctx, ranges).await;
157167
bench_datafusion_temporal_projection(
158168
"datafusion-uwheel-select(*)-count-filter",
159-
&uwheel_ctx,
160-
&ranges,
169+
uwheel_ctx,
170+
ranges,
161171
)
162172
.await;
163173
}
@@ -244,13 +254,11 @@ async fn bench_datafusion_count(id: &str, ctx: &SessionContext, ranges: &[(u64,
244254
.map(|(start, end)| {
245255
let start = DateTime::from_timestamp_millis(start as i64)
246256
.unwrap()
247-
.to_utc()
248-
.naive_utc()
257+
.to_rfc3339()
249258
.to_string();
250259
let end = DateTime::from_timestamp_millis(end as i64)
251260
.unwrap()
252-
.to_utc()
253-
.naive_utc()
261+
.to_rfc3339()
254262
.to_string();
255263
format!(
256264
"SELECT COUNT(*) FROM yellow_tripdata \
@@ -298,18 +306,16 @@ async fn bench_datafusion_sum_fare_amount(id: &str, ctx: &SessionContext, ranges
298306
.map(|(start, end)| {
299307
let start = DateTime::from_timestamp_millis(start as i64)
300308
.unwrap()
301-
.to_utc()
302-
.naive_utc()
309+
.to_rfc3339()
303310
.to_string();
304311
let end = DateTime::from_timestamp_millis(end as i64)
305312
.unwrap()
306-
.to_utc()
307-
.naive_utc()
313+
.to_rfc3339()
308314
.to_string();
309315
format!(
310316
"SELECT SUM(fare_amount) FROM yellow_tripdata \
311-
WHERE tpep_dropoff_datetime >= '{}' \
312-
AND tpep_dropoff_datetime < '{}'",
317+
WHERE tpep_dropoff_datetime >= '{}' \
318+
AND tpep_dropoff_datetime < '{}'",
313319
start, end
314320
)
315321
})
@@ -358,13 +364,11 @@ async fn bench_min_max_projection(
358364
.map(|((start, end), fare)| {
359365
let start = DateTime::from_timestamp_millis(start as i64)
360366
.unwrap()
361-
.to_utc()
362-
.naive_utc()
367+
.to_rfc3339()
363368
.to_string();
364369
let end = DateTime::from_timestamp_millis(end as i64)
365370
.unwrap()
366-
.to_utc()
367-
.naive_utc()
371+
.to_rfc3339()
368372
.to_string();
369373
format!(
370374
"SELECT * FROM yellow_tripdata \
@@ -382,7 +386,7 @@ async fn bench_min_max_projection(
382386
// dbg!(&query);
383387
let now = Instant::now();
384388
let df = ctx.sql(&query).await.unwrap();
385-
let res = df.collect().await.unwrap();
389+
let _res = df.collect().await.unwrap();
386390
hist.record(now.elapsed().as_micros() as u64).unwrap();
387391
}
388392
let runtime = full.elapsed();
@@ -409,13 +413,11 @@ async fn bench_datafusion_temporal_projection(
409413
.map(|(start, end)| {
410414
let start = DateTime::from_timestamp_millis(start as i64)
411415
.unwrap()
412-
.to_utc()
413-
.naive_utc()
416+
.to_rfc3339()
414417
.to_string();
415418
let end = DateTime::from_timestamp_millis(end as i64)
416419
.unwrap()
417-
.to_utc()
418-
.naive_utc()
420+
.to_rfc3339()
419421
.to_string();
420422
format!(
421423
"SELECT * FROM yellow_tripdata \
@@ -432,7 +434,7 @@ async fn bench_datafusion_temporal_projection(
432434
// dbg!(&query);
433435
let now = Instant::now();
434436
let df = ctx.sql(&query).await.unwrap();
435-
let res = df.collect().await.unwrap();
437+
let _res = df.collect().await.unwrap();
436438
hist.record(now.elapsed().as_micros() as u64).unwrap();
437439
}
438440
let runtime = full.elapsed();

datafusion-uwheel/Cargo.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@ edition.workspace = true
66
[dependencies]
77
datafusion.workspace = true
88
uwheel.workspace = true
9-
async-trait.workspace = true
109
chrono.workspace = true
1110
bitpacking.workspace = true
12-
futures.workspace = true
13-
1411

1512
[dev-dependencies]
1613
tokio = "1.38.1"

datafusion-uwheel/examples/nyc_taxi.rs

Lines changed: 16 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,11 @@ use datafusion::{
77
listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl},
88
},
99
error::Result,
10-
physical_plan::{coalesce_batches::CoalesceBatchesExec, collect, empty::EmptyExec},
11-
prelude::{col, count, lit, SessionContext},
10+
physical_plan::collect,
11+
prelude::{col, lit, SessionContext},
1212
scalar::ScalarValue,
1313
};
14-
use datafusion_uwheel::{
15-
builder::Builder,
16-
exec::{UWheelCountExec, UWheelSumExec},
17-
AggregateType, IndexBuilder, UWheelOptimizer,
18-
};
14+
use datafusion_uwheel::{builder::Builder, AggregateType, IndexBuilder, UWheelOptimizer};
1915

2016
#[tokio::main(flavor = "current_thread")]
2117
async fn main() -> Result<()> {
@@ -47,41 +43,39 @@ async fn main() -> Result<()> {
4743
Builder::new("tpep_dropoff_datetime")
4844
.with_name("yellow_tripdata")
4945
.with_min_max_wheels(vec!["fare_amount", "trip_distance"])
50-
.with_sum_wheels(vec!["fare_amount"])
5146
.build_with_provider(provider)
5247
.await
5348
.unwrap(),
5449
);
5550

51+
// Build a wheel SUM on fare_amount
52+
let builder = IndexBuilder::with_col_and_aggregate("fare_amount", AggregateType::Sum);
53+
optimizer.build_index(builder).await?;
54+
5655
// Build a wheel for a custom expression
5756
let builder = IndexBuilder::with_col_and_aggregate("fare_amount", AggregateType::Sum)
5857
.with_filter(col("passenger_count").eq(lit(ScalarValue::Float64(Some(4.0)))));
5958

6059
optimizer.build_index(builder).await?;
6160

6261
// Set UWheelOptimizer as the query planner
63-
let session_state = ctx.state().with_query_planner(optimizer.clone());
62+
let session_state = ctx.state().with_optimizer_rules(vec![optimizer.clone()]);
6463
let ctx = SessionContext::new_with_state(session_state);
6564

6665
// Register the table using the underlying provider
6766
ctx.register_table("yellow_tripdata", optimizer.provider())
6867
.unwrap();
6968

7069
// This query will then use the UWheelOptimizer to execute
71-
let plan = ctx
70+
let df = ctx
7271
.sql(
7372
"SELECT COUNT(*) FROM yellow_tripdata
7473
WHERE tpep_dropoff_datetime >= '2022-01-01T00:00:00Z'
7574
AND tpep_dropoff_datetime < '2022-02-01T00:00:00Z'",
7675
)
77-
.await?
78-
.create_physical_plan()
7976
.await?;
8077

81-
// The plan should be a UWheelCountExec
82-
let uwheel_exec = plan.as_any().downcast_ref::<UWheelCountExec>().unwrap();
83-
dbg!(uwheel_exec);
84-
78+
let plan = df.create_physical_plan().await?;
8579
// Execute the plan
8680
let results: Vec<RecordBatch> = collect(plan, ctx.task_ctx()).await?;
8781
arrow::util::pretty::print_batches(&results).unwrap();
@@ -99,7 +93,7 @@ async fn main() -> Result<()> {
9993
let results: Vec<RecordBatch> = collect(sum_plan, ctx.task_ctx()).await?;
10094
arrow::util::pretty::print_batches(&results).unwrap();
10195

102-
let physical_plan = ctx
96+
let filter_plan = ctx
10397
.sql(
10498
"SELECT * FROM yellow_tripdata
10599
WHERE tpep_dropoff_datetime >= '2022-01-01T00:00:00Z'
@@ -109,23 +103,8 @@ async fn main() -> Result<()> {
109103
.create_physical_plan()
110104
.await?;
111105

112-
// Verify that the plan is optimized to EmptyExec
113-
assert!(physical_plan.as_any().downcast_ref::<EmptyExec>().is_some());
114-
115-
let physical_plan = ctx
116-
.sql(
117-
"SELECT * FROM yellow_tripdata
118-
WHERE tpep_dropoff_datetime >= '2022-01-01T00:00:00Z'
119-
AND tpep_dropoff_datetime < '2022-01-02T00:00:00Z'",
120-
)
121-
.await?
122-
.create_physical_plan()
123-
.await?;
124-
125-
assert!(physical_plan
126-
.as_any()
127-
.downcast_ref::<CoalesceBatchesExec>()
128-
.is_some());
106+
let results: Vec<RecordBatch> = collect(filter_plan, ctx.task_ctx()).await?;
107+
arrow::util::pretty::print_batches(&results).unwrap();
129108

130109
let min_max = ctx
131110
.sql(
@@ -137,25 +116,9 @@ async fn main() -> Result<()> {
137116
.await?
138117
.create_physical_plan()
139118
.await?;
140-
// verify that it returned an EmptyExec
141-
assert!(min_max.as_any().downcast_ref::<EmptyExec>().is_some());
142-
143-
let between_plan = ctx
144-
.sql(
145-
"SELECT COUNT(*) FROM yellow_tripdata
146-
WHERE tpep_dropoff_datetime BETWEEN '2022-01-01T00:00:00Z'
147-
AND '2022-02-01T00:00:00Z'",
148-
)
149-
.await?
150-
.create_physical_plan()
151-
.await?;
152-
153-
// The plan should be a UWheelCountExec
154-
let uwheel_exec = between_plan
155-
.as_any()
156-
.downcast_ref::<UWheelCountExec>()
157-
.unwrap();
158-
dbg!(uwheel_exec);
119+
let results: Vec<RecordBatch> = collect(min_max, ctx.task_ctx()).await?;
120+
assert!(results.is_empty());
121+
arrow::util::pretty::print_batches(&results).unwrap();
159122

160123
// We created an index for this SQL query earlier so it execute as a UWheelSumExec
161124
let sum_keyed_plan = ctx
@@ -169,11 +132,6 @@ async fn main() -> Result<()> {
169132
.create_physical_plan()
170133
.await?;
171134

172-
assert!(sum_keyed_plan
173-
.as_any()
174-
.downcast_ref::<UWheelSumExec>()
175-
.is_some());
176-
177135
let results: Vec<RecordBatch> = collect(sum_keyed_plan, ctx.task_ctx()).await?;
178136
arrow::util::pretty::print_batches(&results).unwrap();
179137

0 commit comments

Comments
 (0)