Skip to content

Commit 4e610a2

Browse files
authored
Fix f16 to f32 coercion to meet Trino f32 minimum (#64)
# Changes Fixes issue where f16 columns were incorrectly being casted to f32. Since Trino does not support float values with width < 32bits, this widens f16 columns to f32 when doing reads to prevent this issue. Stacktrace: ``` Caused by: java.lang.ClassCastException: class org.apache.arrow.vector.Float2Vector cannot be cast to class org.apache.arrow.vector.Float4Vector (org.apache.arrow.vector.Float2Vector and org.apache.arrow.vector.Float4Vector are in unnamed module of loader io.trino.server.PluginClassLoader @24e3d69f) at io.trino.plugin.lance.LanceArrowToPageScanner.lambda$convertType$6(LanceArrowToPageScanner.java:337) at io.trino.plugin.lance.LanceArrowToPageScanner.writeVectorValues(LanceArrowToPageScanner.java:425) at io.trino.plugin.lance.LanceArrowToPageScanner.convertType(LanceArrowToPageScanner.java:337) ... 45 more ```
1 parent 06c6d68 commit 4e610a2

7 files changed

Lines changed: 330 additions & 7 deletions

File tree

plugin/trino-lance/src/main/java/io/trino/plugin/lance/LanceArrowToPageScanner.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.apache.arrow.vector.BitVector;
3333
import org.apache.arrow.vector.DateDayVector;
3434
import org.apache.arrow.vector.FieldVector;
35+
import org.apache.arrow.vector.Float2Vector;
3536
import org.apache.arrow.vector.Float4Vector;
3637
import org.apache.arrow.vector.Float8Vector;
3738
import org.apache.arrow.vector.IntVector;
@@ -334,8 +335,15 @@ else if (type.equals(TIME_MICROS)) {
334335
}
335336
else if (type.equals(REAL)) {
336337
// REAL stores float bits as int which is widened to long
337-
writeVectorValues(output, vector, index -> type.writeLong(output,
338-
Float.floatToIntBits(((Float4Vector) vector).get(index))), offset, length);
338+
if (vector instanceof Float2Vector f2v) {
339+
// Widen float16 to float32 since Trino has no float16 type
340+
writeVectorValues(output, vector, index -> type.writeLong(output,
341+
Float.floatToIntBits(f2v.getValueAsFloat(index))), offset, length);
342+
}
343+
else {
344+
writeVectorValues(output, vector, index -> type.writeLong(output,
345+
Float.floatToIntBits(((Float4Vector) vector).get(index))), offset, length);
346+
}
339347
}
340348
else if (type instanceof TimestampWithTimeZoneType) {
341349
// Timestamp with timezone - stored as microseconds in Arrow

plugin/trino-lance/src/main/java/io/trino/plugin/lance/LanceColumnHandle.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,11 @@ else if (intType.getBitWidth() == 64) {
155155
}
156156
}
157157
else if (type instanceof ArrowType.FloatingPoint fpType) {
158-
if (fpType.getPrecision() == FloatingPointPrecision.SINGLE) {
159-
return REAL;
158+
if (fpType.getPrecision() == FloatingPointPrecision.DOUBLE) {
159+
return DOUBLE;
160160
}
161-
return DOUBLE;
161+
// SINGLE and HALF (float16) both map to REAL — Trino has no float16 type
162+
return REAL;
162163
}
163164
else if (type instanceof ArrowType.Utf8) {
164165
return VARCHAR;
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.plugin.lance;
15+
16+
import com.google.common.collect.ImmutableMap;
17+
import com.google.common.io.Resources;
18+
import io.airlift.json.JsonCodec;
19+
import io.airlift.slice.Slice;
20+
import io.trino.spi.Page;
21+
import io.trino.spi.block.Block;
22+
import io.trino.spi.connector.ConnectorSplitSource;
23+
import io.trino.spi.connector.ConnectorTableHandle;
24+
import io.trino.spi.connector.SchemaTableName;
25+
import io.trino.spi.type.ArrayType;
26+
import io.trino.testing.TestingConnectorSession;
27+
import org.junit.jupiter.api.BeforeEach;
28+
import org.junit.jupiter.api.Test;
29+
import org.junit.jupiter.api.TestInstance;
30+
31+
import java.net.URL;
32+
import java.util.Collections;
33+
import java.util.List;
34+
import java.util.Map;
35+
import java.util.Optional;
36+
37+
import static io.trino.spi.type.BigintType.BIGINT;
38+
import static io.trino.spi.type.BooleanType.BOOLEAN;
39+
import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
40+
import static io.trino.spi.type.DateType.DATE;
41+
import static io.trino.spi.type.DoubleType.DOUBLE;
42+
import static io.trino.spi.type.IntegerType.INTEGER;
43+
import static io.trino.spi.type.RealType.REAL;
44+
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
45+
import static io.trino.spi.type.VarbinaryType.VARBINARY;
46+
import static io.trino.spi.type.VarcharType.VARCHAR;
47+
import static org.assertj.core.api.Assertions.assertThat;
48+
import static org.assertj.core.api.Assertions.within;
49+
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_METHOD;
50+
51+
/**
52+
* Exercises every type path in LanceArrowToPageScanner.convertType using a wide dataset
53+
* (wide_types_table.lance). Each assertion targets a single (column, row) pair.
54+
*
55+
* Dataset schema and Arrow→Trino mappings:
56+
* id int64 → BIGINT (BigIntVector)
57+
* col_bool bool → BOOLEAN (BitVector)
58+
* col_int32 int32 → INTEGER (IntVector)
59+
* col_int64 int64 → BIGINT (BigIntVector, signed)
60+
* col_uint64 uint64 → BIGINT (UInt8Vector, unsigned)
61+
* col_float16 float16 → REAL (Float2Vector widened to float32)
62+
* col_float32 float32 → REAL (Float4Vector)
63+
* col_float64 float64 → DOUBLE (Float8Vector)
64+
* col_string utf8 → VARCHAR (VarCharVector)
65+
* col_binary binary → VARBINARY (VarBinaryVector)
66+
* col_date date32 → DATE (DateDayVector, days since epoch)
67+
* col_ts timestamp[us] → TIMESTAMP_TZ_MILLIS (Lance JNI promotes to UTC)
68+
* col_ts_tz timestamp[us, UTC] → TIMESTAMP_TZ_MILLIS (TimeStampMicroTZVector)
69+
* col_list_f32 list(float32) → ARRAY(REAL) (ListVector + Float4Vector)
70+
* col_fsl_f32 fsl(float32)[3] → ARRAY(REAL) (FixedSizeListVector + Float4Vector)
71+
* col_fsl_f16 fsl(float16)[3] → ARRAY(REAL) (FixedSizeListVector + Float2Vector widened)
72+
*/
73+
@TestInstance(PER_METHOD)
74+
public class TestLanceArrowToPageScanner
75+
{
76+
private static final SchemaTableName WIDE_TABLE = new SchemaTableName("default", "wide_types_table");
77+
78+
// 2024-01-15 10:30:00 UTC
79+
private static final long ROW0_DATE_DAYS = 19737L;
80+
private static final long ROW0_TS_MILLIS = 1705314600000L;
81+
82+
// 2024-06-30 20:00:00 UTC
83+
private static final long ROW1_DATE_DAYS = 19904L;
84+
private static final long ROW1_TS_MILLIS = 1719777600000L;
85+
86+
private LanceMetadata metadata;
87+
private LanceSplitManager splitManager;
88+
private LanceRuntime runtime;
89+
private Page page;
90+
private List<LanceColumnHandle> columns;
91+
92+
@BeforeEach
93+
public void setUp()
94+
throws Exception
95+
{
96+
URL lanceURL = Resources.getResource(TestLanceArrowToPageScanner.class, "/example_db");
97+
assertThat(lanceURL).describedAs("example db is null").isNotNull();
98+
LanceConfig lanceConfig = new LanceConfig().setSingleLevelNs(true);
99+
Map<String, String> catalogProperties = ImmutableMap.of("lance.root", lanceURL.toString());
100+
runtime = new LanceRuntime(lanceConfig, catalogProperties);
101+
JsonCodec<LanceCommitTaskData> commitTaskDataCodec = JsonCodec.jsonCodec(LanceCommitTaskData.class);
102+
JsonCodec<LanceMergeCommitData> mergeCommitDataCodec = JsonCodec.jsonCodec(LanceMergeCommitData.class);
103+
metadata = new LanceMetadata(runtime, lanceConfig, commitTaskDataCodec, mergeCommitDataCodec);
104+
splitManager = new LanceSplitManager(runtime, lanceConfig);
105+
106+
ConnectorTableHandle tableHandle = metadata.getTableHandle(
107+
TestingConnectorSession.SESSION, WIDE_TABLE, Optional.empty(), Optional.empty());
108+
LanceTableHandle lanceTableHandle = (LanceTableHandle) tableHandle;
109+
110+
ConnectorSplitSource splits = splitManager.getSplits(
111+
null, TestingConnectorSession.SESSION, tableHandle, null, null);
112+
LanceSplit lanceSplit = (LanceSplit) splits.getNextBatch(10).get().getSplits().get(0);
113+
114+
columns = runtime.getColumnHandleList(null, lanceTableHandle.getTablePath(), null, Collections.emptyMap());
115+
116+
try (LanceFragmentPageSource pageSource = new LanceFragmentPageSource(
117+
lanceTableHandle, columns, lanceSplit.getFragments(), Collections.emptyMap(), 8192, null, runtime)) {
118+
page = pageSource.getNextPage();
119+
}
120+
121+
assertThat(page).isNotNull();
122+
assertThat(page.getPositionCount()).isEqualTo(2);
123+
}
124+
125+
@Test
126+
public void testBigint()
127+
{
128+
assertBigint("id", 0, 1L);
129+
assertBigint("id", 1, 2L);
130+
}
131+
132+
@Test
133+
public void testBoolean()
134+
{
135+
assertBoolean("col_bool", 0, true);
136+
assertBoolean("col_bool", 1, false);
137+
}
138+
139+
@Test
140+
public void testInteger()
141+
{
142+
assertInteger("col_int32", 0, 10);
143+
assertInteger("col_int32", 1, -10);
144+
}
145+
146+
@Test
147+
public void testBigintSigned()
148+
{
149+
assertBigint("col_int64", 0, 100L);
150+
assertBigint("col_int64", 1, -100L);
151+
}
152+
153+
@Test
154+
public void testBigintUnsigned()
155+
{
156+
// uint64 stored in UInt8Vector, read as signed long
157+
assertBigint("col_uint64", 0, 42L);
158+
assertBigint("col_uint64", 1, 99L);
159+
}
160+
161+
@Test
162+
public void testFloat16WideningToReal()
163+
{
164+
// Float2Vector widened to REAL (float32)
165+
assertReal("col_float16", 0, 3.5f);
166+
assertReal("col_float16", 1, -3.5f);
167+
}
168+
169+
@Test
170+
public void testFloat32()
171+
{
172+
assertReal("col_float32", 0, 1.5f);
173+
assertReal("col_float32", 1, -1.5f);
174+
}
175+
176+
@Test
177+
public void testFloat64()
178+
{
179+
assertDouble("col_float64", 0, 2.5);
180+
assertDouble("col_float64", 1, -2.5);
181+
}
182+
183+
@Test
184+
public void testVarchar()
185+
{
186+
assertVarchar("col_string", 0, "hello");
187+
assertVarchar("col_string", 1, "world");
188+
}
189+
190+
@Test
191+
public void testVarbinary()
192+
{
193+
assertVarbinary("col_binary", 0, new byte[] {0x01, 0x02});
194+
assertVarbinary("col_binary", 1, new byte[] {0x03, 0x04});
195+
}
196+
197+
@Test
198+
public void testDate()
199+
{
200+
assertDate("col_date", 0, ROW0_DATE_DAYS);
201+
assertDate("col_date", 1, ROW1_DATE_DAYS);
202+
}
203+
204+
@Test
205+
public void testTimestampNoTz()
206+
{
207+
// Lance JNI promotes timestamp[us] (no TZ) to UTC internally → TIMESTAMP_TZ_MILLIS
208+
assertTimestampTzMillis("col_ts", 0, ROW0_TS_MILLIS);
209+
assertTimestampTzMillis("col_ts", 1, ROW1_TS_MILLIS);
210+
}
211+
212+
@Test
213+
public void testTimestampWithTz()
214+
{
215+
assertTimestampTzMillis("col_ts_tz", 0, ROW0_TS_MILLIS);
216+
assertTimestampTzMillis("col_ts_tz", 1, ROW1_TS_MILLIS);
217+
}
218+
219+
@Test
220+
public void testListFloat32()
221+
{
222+
assertArrayReal("col_list_f32", 0, new float[] {1.0f, 2.0f});
223+
assertArrayReal("col_list_f32", 1, new float[] {3.0f, 4.0f, 5.0f});
224+
}
225+
226+
@Test
227+
public void testFixedSizeListFloat32()
228+
{
229+
assertArrayReal("col_fsl_f32", 0, new float[] {1.0f, 2.0f, 3.0f});
230+
assertArrayReal("col_fsl_f32", 1, new float[] {4.0f, 5.0f, 6.0f});
231+
}
232+
233+
@Test
234+
public void testFixedSizeListFloat16WideningToReal()
235+
{
236+
// Float2Vector elements widened to REAL inside a FixedSizeListVector
237+
assertArrayReal("col_fsl_f16", 0, new float[] {7.0f, 8.0f, 9.0f});
238+
assertArrayReal("col_fsl_f16", 1, new float[] {10.0f, 11.0f, 12.0f});
239+
}
240+
241+
// --- per-value assertion helpers ---
242+
243+
private Block blockFor(String name)
244+
{
245+
int idx = columns.stream().map(LanceColumnHandle::name).toList().indexOf(name);
246+
assertThat(idx).describedAs("column not found: " + name).isGreaterThanOrEqualTo(0);
247+
return page.getBlock(idx);
248+
}
249+
250+
private void assertBigint(String col, int row, long expected)
251+
{
252+
assertThat(BIGINT.getLong(blockFor(col), row)).isEqualTo(expected);
253+
}
254+
255+
private void assertBoolean(String col, int row, boolean expected)
256+
{
257+
assertThat(BOOLEAN.getBoolean(blockFor(col), row)).isEqualTo(expected);
258+
}
259+
260+
private void assertInteger(String col, int row, int expected)
261+
{
262+
assertThat((int) INTEGER.getLong(blockFor(col), row)).isEqualTo(expected);
263+
}
264+
265+
private void assertReal(String col, int row, float expected)
266+
{
267+
float actual = Float.intBitsToFloat((int) REAL.getLong(blockFor(col), row));
268+
assertThat(actual).isCloseTo(expected, within(0.001f));
269+
}
270+
271+
private void assertDouble(String col, int row, double expected)
272+
{
273+
assertThat(DOUBLE.getDouble(blockFor(col), row)).isCloseTo(expected, within(0.001));
274+
}
275+
276+
private void assertVarchar(String col, int row, String expected)
277+
{
278+
assertThat(VARCHAR.getSlice(blockFor(col), row).toStringUtf8()).isEqualTo(expected);
279+
}
280+
281+
private void assertVarbinary(String col, int row, byte[] expected)
282+
{
283+
Slice slice = VARBINARY.getSlice(blockFor(col), row);
284+
assertThat(slice.getBytes()).isEqualTo(expected);
285+
}
286+
287+
private void assertDate(String col, int row, long expectedDays)
288+
{
289+
assertThat(DATE.getLong(blockFor(col), row)).isEqualTo(expectedDays);
290+
}
291+
292+
private void assertTimestampTzMillis(String col, int row, long expectedMillis)
293+
{
294+
long packed = TIMESTAMP_TZ_MILLIS.getLong(blockFor(col), row);
295+
assertThat(unpackMillisUtc(packed)).isEqualTo(expectedMillis);
296+
}
297+
298+
private void assertArrayReal(String col, int row, float[] expectedElements)
299+
{
300+
LanceColumnHandle handle = columns.stream()
301+
.filter(c -> c.name().equals(col))
302+
.findFirst()
303+
.orElseThrow(() -> new AssertionError("column not found: " + col));
304+
ArrayType arrayType = (ArrayType) handle.trinoType();
305+
Block inner = (Block) arrayType.getObject(blockFor(col), row);
306+
assertThat(inner.getPositionCount()).isEqualTo(expectedElements.length);
307+
for (int i = 0; i < expectedElements.length; i++) {
308+
float actual = Float.intBitsToFloat((int) REAL.getLong(inner, i));
309+
assertThat(actual).isCloseTo(expectedElements[i], within(0.001f));
310+
}
311+
}
312+
}

plugin/trino-lance/src/test/java/io/trino/plugin/lance/TestLanceMetadata.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,16 @@ public void testListTables()
162162
new SchemaTableName("default", "test_table2"),
163163
new SchemaTableName("default", "test_table3"),
164164
new SchemaTableName("default", "test_table4"),
165-
new SchemaTableName("default", "test_table5")));
165+
new SchemaTableName("default", "test_table5"),
166+
new SchemaTableName("default", "wide_types_table")));
166167

167168
// specific schema
168169
assertThat(ImmutableSet.copyOf(metadata.listTables(SESSION, Optional.of("default")))).isEqualTo(ImmutableSet.of(
169170
new SchemaTableName("default", "test_table1"),
170171
new SchemaTableName("default", "test_table2"),
171172
new SchemaTableName("default", "test_table3"),
172173
new SchemaTableName("default", "test_table4"),
173-
new SchemaTableName("default", "test_table5")));
174+
new SchemaTableName("default", "test_table5"),
175+
new SchemaTableName("default", "wide_types_table")));
174176
}
175177
}

0 commit comments

Comments
 (0)