Skip to content

Commit 16de220

Browse files
authored
apacheGH-47721: [C++][FlightRPC] Return ODBC Column Attribute from result set (apache#48050)
### Rationale for this change Implement ODBC to return column attribute values from the result set ### What changes are included in this PR? - SQLColAttribute & Tests ### Are these changes tested? Tested on local MSVC ### Are there any user-facing changes? N/A * GitHub Issue: apache#47721 Authored-by: Alina (Xi) Li <alina.li@improving.com> Signed-off-by: David Li <li.davidm96@gmail.com>
1 parent 625465b commit 16de220

7 files changed

Lines changed: 1365 additions & 54 deletions

File tree

cpp/src/arrow/flight/sql/column_metadata.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ const char* ColumnMetadata::kIsSearchable = "ARROW:FLIGHT:SQL:IS_SEARCHABLE";
5858
const char* ColumnMetadata::kRemarks = "ARROW:FLIGHT:SQL:REMARKS";
5959

6060
ColumnMetadata::ColumnMetadata(
61-
std::shared_ptr<const arrow::KeyValueMetadata> metadata_map)
62-
: metadata_map_(std::move(metadata_map)) {}
61+
std::shared_ptr<const arrow::KeyValueMetadata> metadata_map) {
62+
metadata_map_ = std::move(metadata_map ? metadata_map
63+
: std::make_shared<arrow::KeyValueMetadata>());
64+
}
6365

6466
arrow::Result<std::string> ColumnMetadata::GetCatalogName() const {
6567
return metadata_map_->Get(kCatalogName);

cpp/src/arrow/flight/sql/odbc/odbc_api.cc

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,8 +1297,90 @@ SQLRETURN SQLColAttribute(SQLHSTMT stmt, SQLUSMALLINT record_number,
12971297
<< ", output_length: " << static_cast<const void*>(output_length)
12981298
<< ", numeric_attribute_ptr: "
12991299
<< static_cast<const void*>(numeric_attribute_ptr);
1300-
// GH-47721 TODO: Implement SQLColAttribute, pre-requisite requires SQLColumns
1301-
return SQL_INVALID_HANDLE;
1300+
1301+
using ODBC::ODBCDescriptor;
1302+
using ODBC::ODBCStatement;
1303+
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
1304+
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
1305+
ODBCDescriptor* ird = statement->GetIRD();
1306+
SQLINTEGER output_length_int;
1307+
switch (field_identifier) {
1308+
// Numeric attributes
1309+
// internal is SQLLEN, no conversion is needed
1310+
case SQL_DESC_DISPLAY_SIZE:
1311+
case SQL_DESC_OCTET_LENGTH: {
1312+
ird->GetField(record_number, field_identifier, numeric_attribute_ptr,
1313+
buffer_length, &output_length_int);
1314+
break;
1315+
}
1316+
// internal is SQLULEN, conversion is needed.
1317+
case SQL_COLUMN_LENGTH: // ODBC 2.0
1318+
case SQL_DESC_LENGTH: {
1319+
SQLULEN temp;
1320+
ird->GetField(record_number, field_identifier, &temp, buffer_length,
1321+
&output_length_int);
1322+
if (numeric_attribute_ptr) {
1323+
*numeric_attribute_ptr = static_cast<SQLLEN>(temp);
1324+
}
1325+
break;
1326+
}
1327+
// internal is SQLINTEGER, conversion is needed.
1328+
case SQL_DESC_AUTO_UNIQUE_VALUE:
1329+
case SQL_DESC_CASE_SENSITIVE:
1330+
case SQL_DESC_NUM_PREC_RADIX: {
1331+
SQLINTEGER temp;
1332+
ird->GetField(record_number, field_identifier, &temp, buffer_length,
1333+
&output_length_int);
1334+
if (numeric_attribute_ptr) {
1335+
*numeric_attribute_ptr = static_cast<SQLLEN>(temp);
1336+
}
1337+
break;
1338+
}
1339+
// internal is SQLSMALLINT, conversion is needed.
1340+
case SQL_DESC_CONCISE_TYPE:
1341+
case SQL_DESC_COUNT:
1342+
case SQL_DESC_FIXED_PREC_SCALE:
1343+
case SQL_DESC_TYPE:
1344+
case SQL_DESC_NULLABLE:
1345+
case SQL_COLUMN_PRECISION: // ODBC 2.0
1346+
case SQL_DESC_PRECISION:
1347+
case SQL_COLUMN_SCALE: // ODBC 2.0
1348+
case SQL_DESC_SCALE:
1349+
case SQL_DESC_SEARCHABLE:
1350+
case SQL_DESC_UNNAMED:
1351+
case SQL_DESC_UNSIGNED:
1352+
case SQL_DESC_UPDATABLE: {
1353+
SQLSMALLINT temp;
1354+
ird->GetField(record_number, field_identifier, &temp, buffer_length,
1355+
&output_length_int);
1356+
if (numeric_attribute_ptr) {
1357+
*numeric_attribute_ptr = static_cast<SQLLEN>(temp);
1358+
}
1359+
break;
1360+
}
1361+
// Character attributes
1362+
case SQL_DESC_BASE_COLUMN_NAME:
1363+
case SQL_DESC_BASE_TABLE_NAME:
1364+
case SQL_DESC_CATALOG_NAME:
1365+
case SQL_DESC_LABEL:
1366+
case SQL_DESC_LITERAL_PREFIX:
1367+
case SQL_DESC_LITERAL_SUFFIX:
1368+
case SQL_DESC_LOCAL_TYPE_NAME:
1369+
case SQL_DESC_NAME:
1370+
case SQL_DESC_SCHEMA_NAME:
1371+
case SQL_DESC_TABLE_NAME:
1372+
case SQL_DESC_TYPE_NAME:
1373+
ird->GetField(record_number, field_identifier, character_attribute_ptr,
1374+
buffer_length, &output_length_int);
1375+
break;
1376+
default:
1377+
throw DriverException("Invalid descriptor field", "HY091");
1378+
}
1379+
if (output_length) {
1380+
*output_length = static_cast<SQLSMALLINT>(output_length_int);
1381+
}
1382+
return SQL_SUCCESS;
1383+
});
13021384
}
13031385

13041386
SQLRETURN SQLGetTypeInfo(SQLHSTMT stmt, SQLSMALLINT data_type) {

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.cc

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include "arrow/flight/sql/column_metadata.h"
2121
#include "arrow/flight/sql/odbc/odbc_impl/platform.h"
2222
#include "arrow/flight/sql/odbc/odbc_impl/util.h"
23+
#include "arrow/type_traits.h"
24+
#include "arrow/util/key_value_metadata.h"
2325

2426
#include <utility>
2527
#include "arrow/flight/sql/odbc/odbc_impl/exceptions.h"
@@ -40,12 +42,8 @@ constexpr int32_t DefaultDecimalPrecision = 38;
4042
constexpr int32_t DefaultLengthForVariableLengthColumns = 1024;
4143

4244
namespace {
43-
std::shared_ptr<const KeyValueMetadata> empty_metadata_map(new KeyValueMetadata);
44-
4545
inline ColumnMetadata GetMetadata(const std::shared_ptr<Field>& field) {
46-
const auto& metadata_map = field->metadata();
47-
48-
ColumnMetadata metadata(metadata_map ? metadata_map : empty_metadata_map);
46+
ColumnMetadata metadata(field->metadata());
4947
return metadata;
5048
}
5149

@@ -207,10 +205,14 @@ size_t FlightSqlResultSetMetadata::GetOctetLength(int column_position) {
207205
.value_or(DefaultLengthForVariableLengthColumns);
208206
}
209207

210-
std::string FlightSqlResultSetMetadata::GetTypeName(int column_position) {
208+
std::string FlightSqlResultSetMetadata::GetTypeName(int column_position,
209+
int16_t data_type) {
211210
ColumnMetadata metadata = GetMetadata(schema_->field(column_position - 1));
212211

213-
return metadata.GetTypeName().ValueOrElse([] { return ""; });
212+
return metadata.GetTypeName().ValueOrElse([data_type] {
213+
// If we get an empty type name, figure out the type name from the data_type.
214+
return util::GetTypeNameFromSqlDataType(data_type);
215+
});
214216
}
215217

216218
Updatability FlightSqlResultSetMetadata::GetUpdatable(int column_position) {
@@ -239,20 +241,14 @@ Searchability FlightSqlResultSetMetadata::IsSearchable(int column_position) {
239241

240242
bool FlightSqlResultSetMetadata::IsUnsigned(int column_position) {
241243
const std::shared_ptr<Field>& field = schema_->field(column_position - 1);
242-
243-
switch (field->type()->id()) {
244-
case Type::UINT8:
245-
case Type::UINT16:
246-
case Type::UINT32:
247-
case Type::UINT64:
248-
return true;
249-
default:
250-
return false;
251-
}
244+
arrow::Type::type type_id = field->type()->id();
245+
// non-decimal and non-numeric types are unsigned.
246+
return !arrow::is_decimal(type_id) &&
247+
(!arrow::is_numeric(type_id) || arrow::is_unsigned_integer(type_id));
252248
}
253249

254250
bool FlightSqlResultSetMetadata::IsFixedPrecScale(int column_position) {
255-
// TODO: Flight SQL column metadata does not have this, should we add to the spec?
251+
// Precision for Arrow data types are modifiable by the user
256252
return false;
257253
}
258254

cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_result_set_metadata.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class FlightSqlResultSetMetadata : public ResultSetMetadata {
7777

7878
size_t GetOctetLength(int column_position) override;
7979

80-
std::string GetTypeName(int column_position) override;
80+
std::string GetTypeName(int column_position, int16_t data_type) override;
8181

8282
Updatability GetUpdatable(int column_position) override;
8383

@@ -87,6 +87,7 @@ class FlightSqlResultSetMetadata : public ResultSetMetadata {
8787

8888
Searchability IsSearchable(int column_position) override;
8989

90+
/// \brief Return true if the column is unsigned or not numeric
9091
bool IsUnsigned(int column_position) override;
9192

9293
bool IsFixedPrecScale(int column_position) override;

cpp/src/arrow/flight/sql/odbc/odbc_impl/odbc_descriptor.cc

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,9 @@ void ODBCDescriptor::GetHeaderField(SQLSMALLINT field_identifier, SQLPOINTER val
276276
GetAttribute(rows_processed_ptr_, value, buffer_length, output_length);
277277
break;
278278
case SQL_DESC_COUNT: {
279-
GetAttribute(highest_one_based_bound_record_, value, buffer_length, output_length);
279+
// highest_one_based_bound_record_ equals number of records + 1
280+
GetAttribute(static_cast<SQLSMALLINT>(highest_one_based_bound_record_ - 1), value,
281+
buffer_length, output_length);
280282
break;
281283
}
282284
default:
@@ -310,54 +312,55 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident
310312
throw DriverException("Invalid descriptor index", "07009");
311313
}
312314

313-
// TODO: Restrict fields based on AppDescriptor IPD, and IRD.
315+
// GH-47867 TODO: Restrict fields based on AppDescriptor IPD, and IRD.
314316

317+
bool length_in_bytes = true;
315318
SQLSMALLINT zero_based_record = record_number - 1;
316319
const DescriptorRecord& record = records_[zero_based_record];
317320
switch (field_identifier) {
318321
case SQL_DESC_BASE_COLUMN_NAME:
319-
GetAttributeUTF8(record.base_column_name, value, buffer_length, output_length,
320-
GetDiagnostics());
322+
GetAttributeSQLWCHAR(record.base_column_name, length_in_bytes, value, buffer_length,
323+
output_length, GetDiagnostics());
321324
break;
322325
case SQL_DESC_BASE_TABLE_NAME:
323-
GetAttributeUTF8(record.base_table_name, value, buffer_length, output_length,
324-
GetDiagnostics());
326+
GetAttributeSQLWCHAR(record.base_table_name, length_in_bytes, value, buffer_length,
327+
output_length, GetDiagnostics());
325328
break;
326329
case SQL_DESC_CATALOG_NAME:
327-
GetAttributeUTF8(record.catalog_name, value, buffer_length, output_length,
328-
GetDiagnostics());
330+
GetAttributeSQLWCHAR(record.catalog_name, length_in_bytes, value, buffer_length,
331+
output_length, GetDiagnostics());
329332
break;
330333
case SQL_DESC_LABEL:
331-
GetAttributeUTF8(record.label, value, buffer_length, output_length,
332-
GetDiagnostics());
334+
GetAttributeSQLWCHAR(record.label, length_in_bytes, value, buffer_length,
335+
output_length, GetDiagnostics());
333336
break;
334337
case SQL_DESC_LITERAL_PREFIX:
335-
GetAttributeUTF8(record.literal_prefix, value, buffer_length, output_length,
336-
GetDiagnostics());
338+
GetAttributeSQLWCHAR(record.literal_prefix, length_in_bytes, value, buffer_length,
339+
output_length, GetDiagnostics());
337340
break;
338341
case SQL_DESC_LITERAL_SUFFIX:
339-
GetAttributeUTF8(record.literal_suffix, value, buffer_length, output_length,
340-
GetDiagnostics());
342+
GetAttributeSQLWCHAR(record.literal_suffix, length_in_bytes, value, buffer_length,
343+
output_length, GetDiagnostics());
341344
break;
342345
case SQL_DESC_LOCAL_TYPE_NAME:
343-
GetAttributeUTF8(record.local_type_name, value, buffer_length, output_length,
344-
GetDiagnostics());
346+
GetAttributeSQLWCHAR(record.local_type_name, length_in_bytes, value, buffer_length,
347+
output_length, GetDiagnostics());
345348
break;
346349
case SQL_DESC_NAME:
347-
GetAttributeUTF8(record.name, value, buffer_length, output_length,
348-
GetDiagnostics());
350+
GetAttributeSQLWCHAR(record.name, length_in_bytes, value, buffer_length,
351+
output_length, GetDiagnostics());
349352
break;
350353
case SQL_DESC_SCHEMA_NAME:
351-
GetAttributeUTF8(record.schema_name, value, buffer_length, output_length,
352-
GetDiagnostics());
354+
GetAttributeSQLWCHAR(record.schema_name, length_in_bytes, value, buffer_length,
355+
output_length, GetDiagnostics());
353356
break;
354357
case SQL_DESC_TABLE_NAME:
355-
GetAttributeUTF8(record.table_name, value, buffer_length, output_length,
356-
GetDiagnostics());
358+
GetAttributeSQLWCHAR(record.table_name, length_in_bytes, value, buffer_length,
359+
output_length, GetDiagnostics());
357360
break;
358361
case SQL_DESC_TYPE_NAME:
359-
GetAttributeUTF8(record.type_name, value, buffer_length, output_length,
360-
GetDiagnostics());
362+
GetAttributeSQLWCHAR(record.type_name, length_in_bytes, value, buffer_length,
363+
output_length, GetDiagnostics());
361364
break;
362365

363366
case SQL_DESC_DATA_PTR:
@@ -367,7 +370,7 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident
367370
case SQL_DESC_OCTET_LENGTH_PTR:
368371
GetAttribute(record.indicator_ptr, value, buffer_length, output_length);
369372
break;
370-
373+
case SQL_COLUMN_LENGTH: // ODBC 2.0
371374
case SQL_DESC_LENGTH:
372375
GetAttribute(record.length, value, buffer_length, output_length);
373376
break;
@@ -407,12 +410,14 @@ void ODBCDescriptor::GetField(SQLSMALLINT record_number, SQLSMALLINT field_ident
407410
case SQL_DESC_PARAMETER_TYPE:
408411
GetAttribute(record.param_type, value, buffer_length, output_length);
409412
break;
413+
case SQL_COLUMN_PRECISION: // ODBC 2.0
410414
case SQL_DESC_PRECISION:
411415
GetAttribute(record.precision, value, buffer_length, output_length);
412416
break;
413417
case SQL_DESC_ROWVER:
414418
GetAttribute(record.row_ver, value, buffer_length, output_length);
415419
break;
420+
case SQL_COLUMN_SCALE: // ODBC 2.0
416421
case SQL_DESC_SCALE:
417422
GetAttribute(record.scale, value, buffer_length, output_length);
418423
break;
@@ -479,6 +484,8 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) {
479484

480485
for (size_t i = 0; i < records_.size(); ++i) {
481486
size_t one_based_index = i + 1;
487+
int16_t concise_type = rsmd->GetConciseType(one_based_index);
488+
482489
records_[i].base_column_name = rsmd->GetBaseColumnName(one_based_index);
483490
records_[i].base_table_name = rsmd->GetBaseTableName(one_based_index);
484491
records_[i].catalog_name = rsmd->GetCatalogName(one_based_index);
@@ -489,9 +496,8 @@ void ODBCDescriptor::PopulateFromResultSetMetadata(ResultSetMetadata* rsmd) {
489496
records_[i].name = rsmd->GetName(one_based_index);
490497
records_[i].schema_name = rsmd->GetSchemaName(one_based_index);
491498
records_[i].table_name = rsmd->GetTableName(one_based_index);
492-
records_[i].type_name = rsmd->GetTypeName(one_based_index);
493-
records_[i].concise_type = GetSqlTypeForODBCVersion(
494-
rsmd->GetConciseType(one_based_index), is_2x_connection_);
499+
records_[i].type_name = rsmd->GetTypeName(one_based_index, concise_type);
500+
records_[i].concise_type = GetSqlTypeForODBCVersion(concise_type, is_2x_connection_);
495501
records_[i].data_ptr = nullptr;
496502
records_[i].indicator_ptr = nullptr;
497503
records_[i].display_size = rsmd->GetColumnDisplaySize(one_based_index);

cpp/src/arrow/flight/sql/odbc/odbc_impl/spi/result_set_metadata.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
#pragma once
1919

20-
#include "arrow/flight/sql/odbc/odbc_impl/types.h"
21-
2220
#include <string>
21+
#include "arrow/flight/sql/odbc/odbc_impl/types.h"
2322

2423
namespace arrow::flight::sql::odbc {
2524

@@ -143,8 +142,9 @@ class ResultSetMetadata {
143142

144143
/// \brief It returns the data type as a string.
145144
/// \param column_position [in] the position of the column, starting from 1.
145+
/// \param data_type [in] the data type of the column.
146146
/// \return the data type string.
147-
virtual std::string GetTypeName(int column_position) = 0;
147+
virtual std::string GetTypeName(int column_position, int16_t data_type) = 0;
148148

149149
/// \brief It returns a numeric values indicate the updatability of the
150150
/// column.

0 commit comments

Comments
 (0)