Describe the bug
We have a use case to provide multiple column values to a UDAF. UDAFs support only one column input (unless I'm mistaken; I'm looking at this, which supports one input data type). Update: this has been resolved by #7096.
To work around this, we tried packing the columns into a struct column and passing that as input to the UDAF, but we're seeing an error with both the SQL API's struct() builtin and the Expr API's BuiltinScalarFunction::Struct.
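For reference, the two packing forms we tried (extracted verbatim from the tests below) are the SQL builtin:

SELECT my_sum(struct(a, b, c)) FROM batch

and the Expr API construction:

Expr::ScalarFunction(ScalarFunction {
    fun: BuiltinScalarFunction::Struct,
    args: vec![col("a"), col("b")],
})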
To Reproduce
Run the tests below and see the following output.
Failures
failures:
---- tests::test_udaf_pack_many_col_struct_sql stdout ----
Error: type_coercion
caused by
Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.
Caused by:
Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c2", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.
---- tests::test_udaf_pack_many_col_struct_expr stdout ----
Error: type_coercion
caused by
Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.
Caused by:
Error during planning: Coercion from [Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])] to the signature Exact([Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])]) failed.
Table
Running cargo test -- --nocapture shows this table:
+----+-------+------------+-------------------+---------+
| a | b | c | d | e |
+----+-------+------------+-------------------+---------+
| 12 | true | hi | {i: 12, j: true} | {i: 12} |
| 11 | false | datafusion | {i: 11, j: false} | {i: 11} |
+----+-------+------------+-------------------+---------+
Tests
use datafusion::{physical_plan::Accumulator, scalar::ScalarValue};
#[tokio::main]
async fn main() {}
#[derive(Default, Debug)]
struct SumUdaf {
sum: u32,
}
impl Accumulator for SumUdaf {
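// update_batch receives one ArrayRef per argument; each row is unpacked here:
// for a struct value, every Int32 field is added to the sum; a bare Int32 is added directly.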
fn update_batch(&mut self, values: &[arrow::array::ArrayRef]) -> datafusion::error::Result<()> {
if values.is_empty() {
return Ok(());
}
let arr = &values[0];
(0..arr.len()).try_for_each(|index| {
let sv = ScalarValue::try_from_array(&arr, index)?;
if let ScalarValue::Struct(Some(values), _) = sv {
for v in values {
if let ScalarValue::Int32(Some(v)) = v {
self.sum += v as u32;
}
}
} else if let ScalarValue::Int32(Some(v)) = sv {
self.sum += v as u32;
}
Ok(())
})
}
fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
Ok(ScalarValue::from(self.sum))
}
fn size(&self) -> usize {
std::mem::size_of_val(self)
}
fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::from(self.sum)])
}
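// Merge the partial UInt32 sums produced by state() on other partitions.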
fn merge_batch(&mut self, states: &[arrow::array::ArrayRef]) -> datafusion::error::Result<()> {
if states.is_empty() {
return Ok(());
}
let arr = &states[0];
(0..arr.len()).try_for_each(|index| {
if let ScalarValue::UInt32(Some(v)) = ScalarValue::try_from_array(arr, index)? {
self.sum += v;
} else {
unreachable!("")
}
Ok(())
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use arrow::{
array::{
downcast_array, ArrayBuilder, BooleanBuilder, Int32Builder, StringBuilder,
StructBuilder, UInt32Array,
},
datatypes::{
DataType as ArrowDataType, Field as ArrowField, FieldRef as ArrowFieldRef,
Fields as ArrowFields, Schema as ArrowSchema,
},
record_batch::RecordBatch,
};
use datafusion::{
logical_expr::{expr::ScalarFunction, AggregateUDF, BuiltinScalarFunction},
prelude::*,
};
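// Builds the two-row batch shown above: scalar columns a, b, c plus pre-built struct columns d and e.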
fn test_data() -> anyhow::Result<RecordBatch> {
let d_fields: Vec<ArrowFieldRef> = vec![
Arc::new(ArrowField::new("i", ArrowDataType::Int32, false)),
Arc::new(ArrowField::new("j", ArrowDataType::Boolean, false)),
];
let e_fields: Vec<ArrowFieldRef> =
vec![Arc::new(ArrowField::new("i", ArrowDataType::Int32, false))];
let schema = ArrowSchema::new(vec![
ArrowField::new("a", ArrowDataType::Int32, false),
ArrowField::new("b", ArrowDataType::Boolean, false),
ArrowField::new("c", ArrowDataType::Utf8, false),
ArrowField::new_struct("d", &*d_fields, false),
ArrowField::new_struct("e", &*e_fields, false),
]);
let mut a_builder = Int32Builder::new();
let mut b_builder = BooleanBuilder::new();
let mut c_builder = StringBuilder::new();
a_builder.append_values(&[12, 11], &[true, true]);
b_builder.append_values(&[true, false], &[true, true])?;
c_builder.append_value("hi");
c_builder.append_value("datafusion");
let struct_builders: Vec<Box<dyn ArrayBuilder>> = vec![
Box::new(Int32Builder::new()),
Box::new(BooleanBuilder::new()),
];
let mut d_builder = StructBuilder::new(d_fields, struct_builders);
d_builder.append(true);
d_builder.append(true);
let i_builder = d_builder
.field_builder::<Int32Builder>(0)
.ok_or_else(|| anyhow::anyhow!("bad builder"))?;
i_builder.append_value(12);
i_builder.append_value(11);
let j_builder = d_builder
.field_builder::<BooleanBuilder>(1)
.ok_or_else(|| anyhow::anyhow!("bad builder"))?;
j_builder.append_value(true);
j_builder.append_value(false);
let mut e_builder = StructBuilder::new(e_fields, vec![Box::new(Int32Builder::new())]);
e_builder.append(true);
e_builder.append(true);
let i_builder = e_builder
.field_builder::<Int32Builder>(0)
.ok_or_else(|| anyhow::anyhow!("bad builder"))?;
i_builder.append_value(12);
i_builder.append_value(11);
let mut builders: Vec<Box<dyn ArrayBuilder>> = vec![
Box::new(a_builder),
Box::new(b_builder),
Box::new(c_builder),
Box::new(d_builder),
Box::new(e_builder),
];
let arrays = builders.iter_mut().map(|b| b.finish()).collect::<Vec<_>>();
let batch = RecordBatch::try_new(Arc::new(schema), arrays)?;
Ok(batch)
}
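// Registers the test batch and the UDAF, then plans the given SQL.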
async fn sql(
sql: impl AsRef<str>,
udaf_input_type: ArrowDataType,
) -> anyhow::Result<DataFrame> {
let ctx = SessionContext::default();
let batch = test_data()?;
ctx.register_batch("batch", batch)?;
ctx.register_udaf(udaf(udaf_input_type));
let df = ctx.sql(sql.as_ref()).await?;
Ok(df)
}
fn dataframe() -> anyhow::Result<DataFrame> {
let ctx = SessionContext::default();
let batch = test_data()?;
let df = ctx.read_batch(batch)?;
Ok(df)
}
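// Creates the my_sum UDAF with the given input type; its state and return types are UInt32.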
fn udaf(input_type: ArrowDataType) -> AggregateUDF {
create_udaf(
"my_sum",
input_type,
Arc::new(ArrowDataType::UInt32),
datafusion::logical_expr::Volatility::Immutable,
Arc::new(|_| Ok(Box::new(SumUdaf::default()))),
Arc::new(vec![ArrowDataType::UInt32]),
)
}
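// Packs the named columns into a single struct expression via the Expr API's BuiltinScalarFunction::Struct.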
fn pack_cols(cols: Vec<impl Into<Column>>) -> Expr {
Expr::ScalarFunction(ScalarFunction {
fun: BuiltinScalarFunction::Struct,
args: cols.into_iter().map(|c| col(c)).collect::<Vec<_>>(),
})
}
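// Collects the aggregate result and checks the single UInt32 output value.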
async fn assert(df: DataFrame, expected: u32) -> anyhow::Result<()> {
let result = df.collect().await?;
let result_arr = result[0].column(0);
let result_arr = downcast_array::<UInt32Array>(result_arr);
let actual = result_arr.value(0);
assert_eq!(expected, actual);
Ok(())
}
#[tokio::test]
async fn test_show() -> anyhow::Result<()> {
let df = dataframe()?;
df.show().await?;
Ok(())
}
#[tokio::test]
async fn test_udaf_existing_struct_many_col_sql() -> anyhow::Result<()> {
let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
ArrowField::new("i", ArrowDataType::Int32, false),
ArrowField::new("j", ArrowDataType::Boolean, false),
]));
let df = sql("SELECT my_sum(d) FROM batch", udaf_input_type).await?;
assert(df, 23).await?;
Ok(())
}
#[tokio::test]
async fn test_udaf_existing_struct_many_col_expr() -> anyhow::Result<()> {
let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
ArrowField::new("i", ArrowDataType::Int32, false),
ArrowField::new("j", ArrowDataType::Boolean, false),
]));
let df = dataframe()?;
let udaf = udaf(udaf_input_type);
let df = df.aggregate(vec![], vec![(udaf.call(vec![col("d")]))])?;
assert(df, 23).await?;
Ok(())
}
#[tokio::test]
async fn test_udaf_existing_struct_one_col_sql() -> anyhow::Result<()> {
let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
ArrowField::new("i", ArrowDataType::Int32, false),
]));
let df = sql("SELECT my_sum(e) FROM batch", udaf_input_type).await?;
assert(df, 23).await?;
Ok(())
}
#[tokio::test]
async fn test_udaf_existing_struct_one_col_expr() -> anyhow::Result<()> {
let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
ArrowField::new("i", ArrowDataType::Int32, false),
]));
let df = dataframe()?;
let udaf = udaf(udaf_input_type);
let df = df.aggregate(vec![], vec![(udaf.call(vec![col("e")]))])?;
assert(df, 23).await?;
Ok(())
}
#[tokio::test]
async fn test_udaf_pack_one_col_struct_sql() -> anyhow::Result<()> {
let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
ArrowField::new("c0", ArrowDataType::Int32, true),
// ArrowField::new("c1", ArrowDataType::Boolean, true),
//ArrowField::new("c2", ArrowDataType::Utf8, true),
]));
let df = sql("SELECT my_sum(struct(a)) FROM batch", udaf_input_type).await?;
assert(df, 23).await?;
Ok(())
}
// FAILS - Treats all struct fields as Utf8
#[tokio::test]
async fn test_udaf_pack_many_col_struct_sql() -> anyhow::Result<()> {
let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
ArrowField::new("c0", ArrowDataType::Int32, true),
ArrowField::new("c1", ArrowDataType::Boolean, true),
ArrowField::new("c2", ArrowDataType::Utf8, true),
]));
let df = sql("SELECT my_sum(struct(a, b, c)) FROM batch", udaf_input_type).await?;
assert(df, 23).await?;
Ok(())
}
#[tokio::test]
async fn test_udaf_pack_one_col_struct_expr() -> anyhow::Result<()> {
let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([ArrowField::new(
"c0",
ArrowDataType::Int32,
true,
)]));
let df = dataframe()?;
let udaf = udaf(udaf_input_type);
let packed_expr = pack_cols(vec!["a"]);
let df = df.aggregate(vec![], vec![udaf.call(vec![packed_expr])])?;
assert(df, 23).await?;
Ok(())
}
// FAILS - Treats all struct fields as Utf8
#[tokio::test]
async fn test_udaf_pack_many_col_struct_expr() -> anyhow::Result<()> {
let udaf_input_type = ArrowDataType::Struct(ArrowFields::from_iter([
ArrowField::new("c0", ArrowDataType::Int32, true),
ArrowField::new("c1", ArrowDataType::Boolean, true),
]));
let df = dataframe()?;
let udaf = udaf(udaf_input_type);
let packed_expr = pack_cols(vec!["a", "b"]);
let df = df.aggregate(vec![], vec![udaf.call(vec![packed_expr])])?;
assert(df, 23).await?;
Ok(())
}
}
Expected behavior
We should be able to create a struct with multiple fields using the SQL API's struct() builtin or the Expr API's BuiltinScalarFunction::Struct and provide that as input to a UDAF.
Additional context
The UDAF here is deliberately simple, just for the example.
Is this a limitation of UDAFs, or could we open an enhancement request to support multiple input columns?
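For context on the resolution: a minimal sketch of how this might look after #7096, assuming create_udaf was changed to accept a Vec of input types so each column can be passed directly with no struct packing (the exact post-#7096 signature is an assumption here, not taken from this issue):

// Hypothetical post-#7096 usage: one DataType per input column (signature assumed).
let udaf = create_udaf(
    "my_sum",
    vec![ArrowDataType::Int32, ArrowDataType::Boolean],
    Arc::new(ArrowDataType::UInt32),
    datafusion::logical_expr::Volatility::Immutable,
    Arc::new(|_| Ok(Box::new(SumUdaf::default()))),
    Arc::new(vec![ArrowDataType::UInt32]),
);
// update_batch would then receive one ArrayRef per argument in `values`.
let df = df.aggregate(vec![], vec![udaf.call(vec![col("a"), col("b")])])?;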