Skip to content

Commit 1393188

Browse files
ovr authored and alamb committed
ARROW-11221: [Rust] DF Implement GROUP BY support for Float32/Float64
Rust's standard library does not implement Eq or Hash for the f32/f64 types, which is why I am using an external library called ordered-float that implements these traits. It is better to use an external library than to implement our own inside this repository. Closes #9175 from ovr/issue-11221 Authored-by: Dmitry Patsura <zaets28rus@gmail.com> Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 96430cc commit 1393188

File tree

6 files changed

+161
-8
lines changed

6 files changed

+161
-8
lines changed

rust/datafusion/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ tokio = { version = "0.2", features = ["macros", "blocking", "rt-core", "rt-thre
6363
log = "^0.4"
6464
md-5 = "^0.9.1"
6565
sha2 = "^0.9.1"
66+
ordered-float = "2.0"
6667

6768
[dev-dependencies]
6869
rand = "0.8"

rust/datafusion/src/physical_plan/group_scalar.rs

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@
1717

1818
//! Defines scalars used to construct groups, ex. in GROUP BY clauses.
1919
20+
use ordered_float::OrderedFloat;
2021
use std::convert::{From, TryFrom};
2122

2223
use crate::error::{DataFusionError, Result};
2324
use crate::scalar::ScalarValue;
2425

25-
/// Enumeration of types that can be used in a GROUP BY expression (all primitives except
26-
/// for floating point numerics)
26+
/// Enumeration of types that can be used in a GROUP BY expression
2727
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
2828
pub(crate) enum GroupByScalar {
29+
Float32(OrderedFloat<f32>),
30+
Float64(OrderedFloat<f64>),
2931
UInt8(u8),
3032
UInt16(u16),
3133
UInt32(u32),
@@ -44,6 +46,12 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
4446

4547
fn try_from(scalar_value: &ScalarValue) -> Result<Self> {
4648
Ok(match scalar_value {
49+
ScalarValue::Float32(Some(v)) => {
50+
GroupByScalar::Float32(OrderedFloat::from(*v))
51+
}
52+
ScalarValue::Float64(Some(v)) => {
53+
GroupByScalar::Float64(OrderedFloat::from(*v))
54+
}
4755
ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v),
4856
ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v),
4957
ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v),
@@ -53,7 +61,9 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
5361
ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
5462
ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
5563
ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
56-
ScalarValue::Int8(None)
64+
ScalarValue::Float32(None)
65+
| ScalarValue::Float64(None)
66+
| ScalarValue::Int8(None)
5767
| ScalarValue::Int16(None)
5868
| ScalarValue::Int32(None)
5969
| ScalarValue::Int64(None)
@@ -80,6 +90,8 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
8090
impl From<&GroupByScalar> for ScalarValue {
8191
fn from(group_by_scalar: &GroupByScalar) -> Self {
8292
match group_by_scalar {
93+
GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())),
94+
GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())),
8395
GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)),
8496
GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)),
8597
GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)),
@@ -101,6 +113,48 @@ mod tests {
101113

102114
use crate::error::{DataFusionError, Result};
103115

116+
macro_rules! scalar_eq_test {
117+
($TYPE:expr, $VALUE:expr) => {{
118+
let scalar_value = $TYPE($VALUE);
119+
let a = GroupByScalar::try_from(&scalar_value).unwrap();
120+
121+
let scalar_value = $TYPE($VALUE);
122+
let b = GroupByScalar::try_from(&scalar_value).unwrap();
123+
124+
assert_eq!(a, b);
125+
}};
126+
}
127+
128+
#[test]
129+
fn test_scalar_ne_non_std() -> Result<()> {
130+
// Test only Scalars with non native Eq, Hash
131+
scalar_eq_test!(ScalarValue::Float32, Some(1.0));
132+
scalar_eq_test!(ScalarValue::Float64, Some(1.0));
133+
134+
Ok(())
135+
}
136+
137+
macro_rules! scalar_ne_test {
138+
($TYPE:expr, $LVALUE:expr, $RVALUE:expr) => {{
139+
let scalar_value = $TYPE($LVALUE);
140+
let a = GroupByScalar::try_from(&scalar_value).unwrap();
141+
142+
let scalar_value = $TYPE($RVALUE);
143+
let b = GroupByScalar::try_from(&scalar_value).unwrap();
144+
145+
assert_ne!(a, b);
146+
}};
147+
}
148+
149+
#[test]
150+
fn test_scalar_eq_non_std() -> Result<()> {
151+
// Test only Scalars with non native Eq, Hash
152+
scalar_ne_test!(ScalarValue::Float32, Some(1.0), Some(2.0));
153+
scalar_ne_test!(ScalarValue::Float64, Some(1.0), Some(2.0));
154+
155+
Ok(())
156+
}
157+
104158
#[test]
105159
fn from_scalar_holding_none() -> Result<()> {
106160
let scalar_value = ScalarValue::Int8(None);
@@ -120,14 +174,14 @@ mod tests {
120174
#[test]
121175
fn from_scalar_unsupported() -> Result<()> {
122176
// Use any ScalarValue type not supported by GroupByScalar.
123-
let scalar_value = ScalarValue::Float32(Some(1.1));
177+
let scalar_value = ScalarValue::LargeUtf8(Some("1.1".to_string()));
124178
let result = GroupByScalar::try_from(&scalar_value);
125179

126180
match result {
127181
Err(DataFusionError::Internal(error_message)) => assert_eq!(
128182
error_message,
129183
String::from(
130-
"Cannot convert a ScalarValue with associated DataType Float32"
184+
"Cannot convert a ScalarValue with associated DataType LargeUtf8"
131185
)
132186
),
133187
_ => panic!("Unexpected result"),

rust/datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ use arrow::error::{ArrowError, Result as ArrowResult};
3535
use arrow::record_batch::RecordBatch;
3636
use arrow::{
3737
array::{
38-
ArrayRef, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
39-
UInt16Array, UInt32Array, UInt64Array, UInt8Array,
38+
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
39+
Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
4040
},
4141
compute,
4242
};
@@ -48,6 +48,7 @@ use super::{
4848
};
4949
use ahash::RandomState;
5050
use hashbrown::HashMap;
51+
use ordered_float::OrderedFloat;
5152

5253
use arrow::array::{TimestampMicrosecondArray, TimestampNanosecondArray};
5354
use async_trait::async_trait;
@@ -685,6 +686,14 @@ fn create_batch_from_map(
685686
// 2.
686687
let mut groups = (0..num_group_expr)
687688
.map(|i| match &group_by_values[i] {
689+
GroupByScalar::Float32(n) => {
690+
Arc::new(Float32Array::from(vec![(*n).into()] as Vec<f32>))
691+
as ArrayRef
692+
}
693+
GroupByScalar::Float64(n) => {
694+
Arc::new(Float64Array::from(vec![(*n).into()] as Vec<f64>))
695+
as ArrayRef
696+
}
688697
GroupByScalar::Int8(n) => {
689698
Arc::new(Int8Array::from(vec![*n])) as ArrayRef
690699
}
@@ -776,6 +785,14 @@ pub(crate) fn create_group_by_values(
776785
for i in 0..group_by_keys.len() {
777786
let col = &group_by_keys[i];
778787
match col.data_type() {
788+
DataType::Float32 => {
789+
let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
790+
vec[i] = GroupByScalar::Float32(OrderedFloat::from(array.value(row)))
791+
}
792+
DataType::Float64 => {
793+
let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
794+
vec[i] = GroupByScalar::Float64(OrderedFloat::from(array.value(row)))
795+
}
779796
DataType::UInt8 => {
780797
let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
781798
vec[i] = GroupByScalar::UInt8(array.value(row))

rust/datafusion/src/physical_plan/hash_join.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
//! into a set of partitions.
2020
2121
use arrow::{
22-
array::{ArrayRef, UInt64Builder},
22+
array::{ArrayRef, Float32Array, Float64Array, UInt64Builder},
2323
compute,
2424
};
2525
use arrow::{
@@ -393,6 +393,14 @@ pub(crate) fn create_key(
393393
vec.clear();
394394
for col in group_by_keys {
395395
match col.data_type() {
396+
DataType::Float32 => {
397+
let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
398+
vec.extend_from_slice(&array.value(row).to_le_bytes());
399+
}
400+
DataType::Float64 => {
401+
let array = col.as_any().downcast_ref::<Float64Array>().unwrap();
402+
vec.extend_from_slice(&array.value(row).to_le_bytes());
403+
}
396404
DataType::UInt8 => {
397405
let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
398406
vec.extend_from_slice(&array.value(row).to_le_bytes());
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
c1,c2
2+
0.00001,0.000000000001
3+
0.00002,0.000000000002
4+
0.00002,0.000000000002
5+
0.00003,0.000000000003
6+
0.00003,0.000000000003
7+
0.00003,0.000000000003
8+
0.00004,0.000000000004
9+
0.00004,0.000000000004
10+
0.00004,0.000000000004
11+
0.00004,0.000000000004
12+
0.00005,0.000000000005
13+
0.00005,0.000000000005
14+
0.00005,0.000000000005
15+
0.00005,0.000000000005
16+
0.00005,0.000000000005

rust/datafusion/tests/sql.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,48 @@ async fn csv_query_group_by_int_min_max() -> Result<()> {
345345
Ok(())
346346
}
347347

348+
#[tokio::test]
349+
async fn csv_query_group_by_float32() -> Result<()> {
350+
let mut ctx = ExecutionContext::new();
351+
register_aggregate_floats_csv(&mut ctx)?;
352+
353+
let sql =
354+
"SELECT COUNT(*) as cnt, c1 FROM aggregate_floats GROUP BY c1 ORDER BY cnt DESC";
355+
let actual = execute(&mut ctx, sql).await;
356+
357+
let expected = vec![
358+
vec!["5", "0.00005"],
359+
vec!["4", "0.00004"],
360+
vec!["3", "0.00003"],
361+
vec!["2", "0.00002"],
362+
vec!["1", "0.00001"],
363+
];
364+
assert_eq!(expected, actual);
365+
366+
Ok(())
367+
}
368+
369+
#[tokio::test]
370+
async fn csv_query_group_by_float64() -> Result<()> {
371+
let mut ctx = ExecutionContext::new();
372+
register_aggregate_floats_csv(&mut ctx)?;
373+
374+
let sql =
375+
"SELECT COUNT(*) as cnt, c2 FROM aggregate_floats GROUP BY c2 ORDER BY cnt DESC";
376+
let actual = execute(&mut ctx, sql).await;
377+
378+
let expected = vec![
379+
vec!["5", "0.000000000005"],
380+
vec!["4", "0.000000000004"],
381+
vec!["3", "0.000000000003"],
382+
vec!["2", "0.000000000002"],
383+
vec!["1", "0.000000000001"],
384+
];
385+
assert_eq!(expected, actual);
386+
387+
Ok(())
388+
}
389+
348390
#[tokio::test]
349391
async fn csv_query_group_by_two_columns() -> Result<()> {
350392
let mut ctx = ExecutionContext::new();
@@ -1325,6 +1367,21 @@ fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
13251367
Ok(())
13261368
}
13271369

1370+
fn register_aggregate_floats_csv(ctx: &mut ExecutionContext) -> Result<()> {
1371+
// It's not possible to use aggregate_test_100: it does not contain enough repeated values to test grouping on floats
1372+
let schema = Arc::new(Schema::new(vec![
1373+
Field::new("c1", DataType::Float32, false),
1374+
Field::new("c2", DataType::Float64, false),
1375+
]));
1376+
1377+
ctx.register_csv(
1378+
"aggregate_floats",
1379+
"tests/aggregate_floats.csv",
1380+
CsvReadOptions::new().schema(&schema),
1381+
)?;
1382+
Ok(())
1383+
}
1384+
13281385
fn register_alltypes_parquet(ctx: &mut ExecutionContext) {
13291386
let testdata = arrow::util::test_util::parquet_test_data();
13301387
ctx.register_parquet(

0 commit comments

Comments
 (0)