Skip to content

Commit efb3075

Browse files
[Data] Fixing remaining issues with custom tensor extensions (#56918)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? While resolving issues that surfaced recently, more issues came up, which prompted me to review the implementations of our tensor (Arrow) extensions and address a variety of issues discovered: 1. Added missing `ArrowVariableShapedTensorType.__eq__ ` (to make sure we can concat blocks holding these) 2. Fixed concatenation of AVSTT to properly reconcile different dimensions 3. Cleaned up and abstracted common utils to unify tensor types and the provided arrays 4. Replaced Python arrays w/ ndarrays wherever possible 5. Deleted a lot of dead code 6. Rebased `ExtensionArray.from_storage` w/ `ExtensionType.wrap_array` ## Related issue number <!-- For example: "Closes #1234" --> ## Checks - [ ] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR. - [ ] I've run `scripts/format.sh` to lint the changes in this PR. - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [ ] Unit tests - [ ] Release tests - [ ] This PR is not tested :( <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Refactors Ray Data’s Arrow tensor extensions with type unification and zero-copy concat, replaces legacy APIs with wrap_array, and enforces PyArrow>=9 across codepaths with updated concat/schema alignment and tests. 
> > - **Tensor Extensions (Arrow)**: > - Introduce `unify_tensor_types`, `unify_tensor_arrays`, and `concat_tensor_arrays` with zero-copy helpers (`_concat_ndarrays`, `_are_contiguous_1d_views`). > - Add robust equality/hash for `ArrowTensorType`/`V2`/`ArrowVariableShapedTensorType`; simplify `ArrowTensorScalar`. > - Replace `ExtensionArray.from_storage(...)` with `ExtensionType.wrap_array(...)`; remove `_concat_same_type` and old chunking utilities. > - Add `to_var_shaped_tensor_array` and shape-padding utilities; optimize `to_numpy`/`from_numpy` and boolean handling. > - **PyArrow Version Enforcement**: > - Add `_check_pyarrow_version` (min `9.0.0`, env override) in `ray/_private/arrow_utils.py`; integrate across Data (object/tensor extensions, util proxy). > - Update tests to validate failure on `pyarrow==8.0.0`; remove version-spoofing fixtures. > - **Arrow Ops & Schema Handling**: > - Update `concat`, schema unification, and struct-field alignment to use tensor-type unification; improved error messages. > - Use `concat_tensor_arrays` in extension column combining. > - **Other**: > - Simplify tensor scalar extraction in Arrow block accessor. > - Tests: add thorough tensor equality/concat/zero-copy cases; set `preserve_order` in limit/split tests; adjust bazel test size for `test_consumption`. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 529de7a. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Signed-off-by: Alexey Kudinkin <ak@anyscale.com>
1 parent 394b97c commit efb3075

16 files changed

Lines changed: 878 additions & 644 deletions

File tree

ci/lint/pydoclint-baseline.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,6 @@ python/ray/air/util/tensor_extensions/arrow.py
415415
DOC101: Function `pyarrow_table_from_pydict`: Docstring contains fewer arguments than in function signature.
416416
DOC103: Function `pyarrow_table_from_pydict`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [pydict: Dict[str, Union[List[Any], pa.Array]]].
417417
DOC201: Function `pyarrow_table_from_pydict` does not have a return section in docstring
418-
DOC201: Method `ArrowTensorArray._concat_same_type` does not have a return section in docstring
419418
--------------------
420419
python/ray/air/util/tensor_extensions/pandas.py
421420
DOC101: Method `TensorDtype.__init__`: Docstring contains fewer arguments than in function signature.

doc/source/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -672,6 +672,8 @@ def filter(self, record):
672672
"ray._raylet",
673673
"ray.core.generated",
674674
"ray.serve.generated",
675+
"ray.air.util.tensor_extensions",
676+
"ray.data._internal.arrow_ops",
675677
]
676678

677679
for mock_target in autodoc_mock_imports:

python/ray/_private/arrow_utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,56 @@
11
import json
2+
import logging
3+
import os
24
from typing import Dict, Optional
35
from urllib.parse import parse_qsl, unquote, urlencode, urlparse, urlunparse
46

57
from packaging.version import Version, parse as parse_version
68

9+
_RAY_DISABLE_PYARROW_VERSION_CHECK = "RAY_DISABLE_PYARROW_VERSION_CHECK"
10+
11+
712
_PYARROW_INSTALLED: Optional[bool] = None
813
_PYARROW_VERSION: Optional[Version] = None
914

1015

16+
# NOTE: Make sure that these lower and upper bounds stay in sync with version
17+
# constraints given in python/setup.py.
18+
# Inclusive minimum pyarrow version.
19+
_PYARROW_SUPPORTED_VERSION_MIN = "9.0.0"
20+
_PYARROW_VERSION_VALIDATED = False
21+
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def _check_pyarrow_version():
27+
"""Checks that Pyarrow's version is within the supported bounds."""
28+
global _PYARROW_VERSION_VALIDATED
29+
30+
if not _PYARROW_VERSION_VALIDATED:
31+
if os.environ.get(_RAY_DISABLE_PYARROW_VERSION_CHECK, "0") == "1":
32+
_PYARROW_VERSION_VALIDATED = True
33+
return
34+
35+
version = get_pyarrow_version()
36+
if version is not None:
37+
if version < parse_version(_PYARROW_SUPPORTED_VERSION_MIN):
38+
raise ImportError(
39+
f"Dataset requires pyarrow >= {_PYARROW_SUPPORTED_VERSION_MIN}, but "
40+
f"{version} is installed. Reinstall with "
41+
f'`pip install -U "pyarrow"`. '
42+
)
43+
else:
44+
logger.warning(
45+
"You are using the 'pyarrow' module, but the exact version is unknown "
46+
"(possibly carried as an internal component by another module). Please "
47+
f"make sure you are using pyarrow >= {_PYARROW_SUPPORTED_VERSION_MIN} to ensure "
48+
"compatibility with Ray Dataset. "
49+
)
50+
51+
_PYARROW_VERSION_VALIDATED = True
52+
53+
1154
def get_pyarrow_version() -> Optional[Version]:
1255
"""Get the version of the pyarrow package or None if not installed."""
1356
global _PYARROW_INSTALLED, _PYARROW_VERSION

python/ray/air/tests/test_tensor_extension.py

Lines changed: 295 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
ArrowTensorTypeV2,
1515
ArrowVariableShapedTensorArray,
1616
ArrowVariableShapedTensorType,
17+
_are_contiguous_1d_views,
18+
_concat_ndarrays,
19+
_extension_array_concat_supported,
20+
concat_tensor_arrays,
21+
unify_tensor_arrays,
1722
)
1823
from ray.air.util.tensor_extensions.pandas import TensorArray, TensorDtype
1924
from ray.air.util.tensor_extensions.utils import create_ragged_ndarray
@@ -715,7 +720,7 @@ def test_arrow_tensor_array_concat(a1, a2, restore_data_context, tensor_format):
715720

716721
ta1 = ArrowTensorArray.from_numpy(a1)
717722
ta2 = ArrowTensorArray.from_numpy(a2)
718-
ta = ArrowTensorArray._concat_same_type([ta1, ta2])
723+
ta = concat_tensor_arrays([ta1, ta2])
719724
assert len(ta) == a1.shape[0] + a2.shape[0]
720725
if a1.shape[1:] == a2.shape[1:]:
721726
if tensor_format == "v1":
@@ -753,8 +758,8 @@ def test_variable_shaped_tensor_array_chunked_concat(
753758
a2 = np.arange(np.prod(shape2)).reshape(shape2)
754759
ta1 = ArrowTensorArray.from_numpy(a1)
755760
ta2 = ArrowTensorArray.from_numpy(a2)
756-
chunked_ta = ArrowTensorArray._chunk_tensor_arrays([ta1, ta2])
757-
ta = ArrowTensorArray._concat_same_type(chunked_ta.chunks)
761+
unified_arrs = unify_tensor_arrays([ta1, ta2])
762+
ta = concat_tensor_arrays(unified_arrs)
758763
assert len(ta) == shape1[0] + shape2[0]
759764
assert isinstance(ta.type, ArrowVariableShapedTensorType)
760765
assert pa.types.is_struct(ta.type.storage_type)
@@ -828,6 +833,293 @@ def test_tensor_array_string_tensors_simple(restore_data_context, tensor_format)
828833
np.testing.assert_array_equal(roundtrip_strings, string_tensors)
829834

830835

836+
def test_tensor_type_equality_checks():
837+
# Test that different types are not equal
838+
fs_tensor_type_v1 = ArrowTensorType((2, 3), pa.int64())
839+
fs_tensor_type_v2 = ArrowTensorTypeV2((2, 3), pa.int64())
840+
841+
assert fs_tensor_type_v1 != fs_tensor_type_v2
842+
843+
# Test different shapes/dtypes aren't equal
844+
assert fs_tensor_type_v1 != ArrowTensorType((3, 3), pa.int64())
845+
assert fs_tensor_type_v1 != ArrowTensorType((2, 3), pa.float64())
846+
assert fs_tensor_type_v2 != ArrowTensorTypeV2((3, 3), pa.int64())
847+
assert fs_tensor_type_v2 != ArrowTensorTypeV2((2, 3), pa.float64())
848+
849+
# Test var-shaped tensor type
850+
vs_tensor_type = ArrowVariableShapedTensorType(pa.int64(), 2)
851+
852+
# Test that different types are not equal
853+
assert vs_tensor_type != ArrowVariableShapedTensorType(pa.int64(), 3)
854+
assert vs_tensor_type != ArrowVariableShapedTensorType(pa.float64(), 2)
855+
assert vs_tensor_type != fs_tensor_type_v1
856+
assert vs_tensor_type != fs_tensor_type_v2
857+
858+
859+
@pytest.mark.skipif(
860+
not _extension_array_concat_supported(),
861+
reason="ExtensionArrays support concatenation only in Pyarrow >= 12.0",
862+
)
863+
def test_arrow_fixed_shape_tensor_type_eq_with_concat(restore_data_context):
864+
"""Test that ArrowTensorType and ArrowTensorTypeV2 __eq__ methods work correctly
865+
when concatenating Arrow arrays with the same tensor type."""
866+
from ray.data.context import DataContext
867+
from ray.data.extensions.tensor_extension import (
868+
ArrowTensorArray,
869+
ArrowTensorType,
870+
ArrowTensorTypeV2,
871+
)
872+
873+
# Test ArrowTensorType V1
874+
tensor_type_v1 = ArrowTensorType((2, 3), pa.int64())
875+
876+
DataContext.get_current().use_arrow_tensor_v2 = False
877+
first = ArrowTensorArray.from_numpy(np.ones((2, 2, 3), dtype=np.int64))
878+
second = ArrowTensorArray.from_numpy(np.zeros((3, 2, 3), dtype=np.int64))
879+
880+
assert first.type == second.type
881+
# Assert commutation
882+
assert tensor_type_v1 == first.type
883+
assert first.type == tensor_type_v1
884+
885+
# Test concatenation works appropriately
886+
concatenated = pa.concat_arrays([first, second])
887+
assert len(concatenated) == 5
888+
assert concatenated.type == tensor_type_v1
889+
890+
expected = np.vstack([first.to_numpy(), second.to_numpy()])
891+
np.testing.assert_array_equal(concatenated.to_numpy(), expected)
892+
893+
# Test ArrowTensorTypeV2
894+
tensor_type_v2 = ArrowTensorTypeV2((2, 3), pa.int64())
895+
896+
DataContext.get_current().use_arrow_tensor_v2 = True
897+
898+
first = ArrowTensorArray.from_numpy(np.ones((2, 2, 3), dtype=np.int64))
899+
second = ArrowTensorArray.from_numpy(np.ones((3, 2, 3), dtype=np.int64))
900+
901+
assert first.type == second.type
902+
# Assert commutation
903+
assert tensor_type_v2 == first.type
904+
assert first.type == tensor_type_v2
905+
906+
# Test concatenation works appropriately
907+
concatenated_v2 = pa.concat_arrays([first, second])
908+
assert len(concatenated_v2) == 5
909+
assert concatenated_v2.type == tensor_type_v2
910+
911+
# Assert on the full concatenated array
912+
expected = np.vstack([first.to_numpy(), second.to_numpy()])
913+
np.testing.assert_array_equal(concatenated_v2.to_numpy(), expected)
914+
915+
916+
@pytest.mark.skipif(
917+
not _extension_array_concat_supported(),
918+
reason="ExtensionArrays support concatenation only in Pyarrow >= 12.0",
919+
)
920+
def test_arrow_variable_shaped_tensor_type_eq_with_concat():
921+
"""Test that ArrowVariableShapedTensorType __eq__ method works correctly
922+
when concatenating Arrow arrays with variable shaped tensors."""
923+
from ray.data.extensions.tensor_extension import (
924+
ArrowVariableShapedTensorArray,
925+
ArrowVariableShapedTensorType,
926+
)
927+
928+
# Create ArrowVariableShapedTensorType
929+
var_tensor_type = ArrowVariableShapedTensorType(pa.int64(), 2)
930+
931+
# Create arrays with variable-shaped tensors
932+
tensors1 = [np.array([[1, 2], [3, 4]]), np.array([[5, 6, 7], [8, 9, 10]])]
933+
tensors2 = [np.array([[11, 12, 13, 14]]), np.array([[15], [16], [17]])]
934+
935+
arr1 = ArrowVariableShapedTensorArray.from_numpy(tensors1)
936+
arr2 = ArrowVariableShapedTensorArray.from_numpy(tensors2)
937+
938+
assert arr1.type == arr2.type
939+
# Assert commutation
940+
assert var_tensor_type == arr1.type
941+
assert arr1.type == var_tensor_type
942+
943+
# Test concatenation works appropriately
944+
concatenated = pa.concat_arrays([arr1, arr2])
945+
assert len(concatenated) == 4
946+
assert concatenated.type == var_tensor_type
947+
948+
result = concatenated.to_numpy()
949+
expected_shapes = [(2, 2), (2, 3), (1, 4), (3, 1)]
950+
for i, expected_shape in enumerate(expected_shapes):
951+
assert result[i].shape == expected_shape
952+
953+
954+
def test_reverse_order():
955+
"""Test views in reverse order."""
956+
base = np.arange(100, dtype=np.float64)
957+
958+
raveled = np.empty(3, dtype=np.object_)
959+
raveled[0] = base[50:60].ravel()
960+
raveled[1] = base[30:50].ravel()
961+
raveled[2] = base[0:30].ravel()
962+
963+
# Reverse order views should NOT be contiguous
964+
assert not _are_contiguous_1d_views(raveled)
965+
966+
967+
def test_concat_ndarrays_zero_copy():
968+
"""Test that _concat_ndarrays performs zero-copy concatenation when possible."""
969+
# Case 1: Create a base array and contiguous views
970+
base = np.arange(100, dtype=np.int64)
971+
972+
arrs = [base[0:20], base[20:50], base[50:100]]
973+
974+
result = _concat_ndarrays(arrs)
975+
976+
np.testing.assert_array_equal(result, base)
977+
# Verify it's a zero-copy view (shares memory with base)
978+
assert np.shares_memory(result, base)
979+
980+
# Case 2: Verify empty views are skipped
981+
arrs = [base[0:10], base[10:10], base[10:20]] # Empty array
982+
983+
result = _concat_ndarrays(arrs)
984+
expected = np.concatenate([base[0:10], base[10:20]])
985+
986+
np.testing.assert_array_equal(result, expected)
987+
# Verify it's a zero-copy view (shares memory with base)
988+
assert np.shares_memory(result, base)
989+
990+
# Case 3: Singleton ndarray is returned as is
991+
result = _concat_ndarrays([base])
992+
993+
# Should return the same array or equivalent
994+
assert result is base
995+
996+
997+
def test_concat_ndarrays_non_contiguous_fallback():
998+
"""Test that _concat_ndarrays falls back to np.concatenate when arrays aren't contiguous."""
999+
1000+
# Case 1: Non-contiguous arrays
1001+
arr1 = np.arange(10, dtype=np.float32)
1002+
_ = np.arange(1000) # Create gap to prevent contiguity
1003+
arr2 = np.arange(10, 20, dtype=np.float32)
1004+
_ = np.arange(1000) # Create gap to prevent contiguity
1005+
arr3 = np.arange(20, 30, dtype=np.float32)
1006+
1007+
arrs = [arr1, arr2, arr3]
1008+
1009+
result = _concat_ndarrays(arrs)
1010+
1011+
expected = np.concatenate(arrs)
1012+
np.testing.assert_array_equal(result, expected)
1013+
1014+
assert all(not np.shares_memory(result, a) for a in arrs)
1015+
1016+
# Case 2: Non-contiguous arrays (take 2)
1017+
base = np.arange(100, dtype=np.float64)
1018+
1019+
arrs = [base[0:10], base[20:30], base[30:40]] # Gap from 10-20
1020+
1021+
result = _concat_ndarrays(arrs)
1022+
expected = np.concatenate(arrs)
1023+
1024+
np.testing.assert_array_equal(result, expected)
1025+
# Should have created a copy since there's a gap
1026+
assert not np.shares_memory(result, base)
1027+
1028+
1029+
def test_concat_ndarrays_diff_dtypes_fallback():
1030+
"""Different dtypes"""
1031+
1032+
base_int16 = np.arange(50, dtype=np.int16)
1033+
base_int32 = np.arange(50, dtype=np.int32)
1034+
1035+
# Different dtypes should use fallback
1036+
arrs = [base_int16, base_int32]
1037+
1038+
# This should use np.concatenate with type promotion
1039+
result = _concat_ndarrays(arrs)
1040+
expected = np.concatenate(arrs)
1041+
1042+
np.testing.assert_array_equal(result, expected)
1043+
assert result.dtype == expected.dtype
1044+
1045+
1046+
def test_are_contiguous_1d_views_non_raveled():
1047+
"""Test that _are_contiguous_1d_views rejects non-1D arrays."""
1048+
base = np.arange(100, dtype=np.int64).reshape(10, 10)
1049+
1050+
arrs = [
1051+
base[0:2].ravel(), # 1D view
1052+
base[2:4], # 2D array
1053+
]
1054+
1055+
# Should reject because second array is not 1D
1056+
assert not _are_contiguous_1d_views(arrs)
1057+
1058+
1059+
def test_are_contiguous_1d_views_non_c_contiguous():
1060+
"""Test _are_contiguous_1d_views with non-C-contiguous arrays."""
1061+
base = np.arange(100, dtype=np.int64).reshape(10, 10)
1062+
1063+
# Column slices are not C-contiguous
1064+
arrs = [base[:, 0], base[:, 1]]
1065+
1066+
assert not _are_contiguous_1d_views(arrs)
1067+
1068+
1069+
def test_are_contiguous_1d_views_different_bases():
1070+
"""Test _are_contiguous_1d_views with views from different base arrays."""
1071+
base1 = np.arange(50, dtype=np.int64)
1072+
_ = np.arange(1000, dtype=np.int64) # Create gap to prevent contiguity
1073+
base2 = np.arange(50, 100, dtype=np.int64)
1074+
1075+
arrs = [base1, base2]
1076+
1077+
# Different base arrays
1078+
assert not _are_contiguous_1d_views(arrs)
1079+
1080+
1081+
def test_are_contiguous_1d_views_overlapping():
1082+
"""Test _are_contiguous_1d_views with overlapping views."""
1083+
base = np.arange(100, dtype=np.float64)
1084+
1085+
arrs = [base[0:20], base[10:30]] # Overlaps with first
1086+
1087+
# Overlapping views are not contiguous
1088+
assert not _are_contiguous_1d_views(arrs)
1089+
1090+
1091+
def test_concat_ndarrays_complex_views():
1092+
"""Test _concat_ndarrays with complex view scenarios."""
1093+
# Create a 2D array and take contiguous row views
1094+
base_2d = np.arange(100, dtype=np.int64).reshape(10, 10)
1095+
base = base_2d.ravel() # Get 1D view
1096+
1097+
# Take contiguous slices of the 1D view
1098+
arrs = [base[0:30], base[30:60], base[60:100]]
1099+
1100+
result = _concat_ndarrays(arrs)
1101+
np.testing.assert_array_equal(result, base)
1102+
assert np.shares_memory(
1103+
result, base_2d
1104+
) # Should share memory with original 2D array
1105+
1106+
1107+
def test_concat_ndarrays_strided_views():
1108+
"""Test _concat_ndarrays with strided (non-contiguous) views."""
1109+
base = np.arange(100, dtype=np.float64)
1110+
1111+
# Every other element - these are strided views
1112+
arrs = [base[::2], base[1::2]] # Even indices # Odd indices
1113+
1114+
# Strided views are not C-contiguous
1115+
result = _concat_ndarrays(arrs)
1116+
expected = np.concatenate(arrs)
1117+
1118+
np.testing.assert_array_equal(result, expected)
1119+
# Should have created a copy
1120+
assert not np.shares_memory(result, base)
1121+
1122+
8311123
if __name__ == "__main__":
8321124
import sys
8331125

0 commit comments

Comments
 (0)