Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -2767,8 +2767,17 @@ cdef class RecordBatch(_Tabular):

Parameters
----------
names : list of str
List of new column names.
names : list[str] or dict[str, str]
List of new column names or mapping of old column names to new column names.

If a mapping of old to new column names is passed, then all columns which are
found to match a provided old column name will be renamed to the new column name.
If any column names are not found in the mapping, a KeyError will be raised.

Raises
------
KeyError
If any of the column names passed in the names mapping do not exist.

Returns
-------
Expand All @@ -2789,13 +2798,38 @@ cdef class RecordBatch(_Tabular):
----
n: [2,4,5,100]
name: ["Flamingo","Horse","Brittle stars","Centipede"]
>>> new_names = {"n_legs": "n", "animals": "name"}
>>> batch.rename_columns(new_names)
pyarrow.RecordBatch
n: int64
name: string
----
n: [2,4,5,100]
name: ["Flamingo","Horse","Brittle stars","Centipede"]
"""
cdef:
shared_ptr[CRecordBatch] c_batch
vector[c_string] c_names

for name in names:
c_names.push_back(tobytes(name))
if isinstance(names, list):
for name in names:
c_names.push_back(tobytes(name))
elif isinstance(names, dict):
idx_to_new_name = {}
for name, new_name in names.items():
indices = self.schema.get_all_field_indices(name)

if not indices:
raise KeyError("Column {!r} not found".format(name))

for index in indices:
idx_to_new_name[index] = new_name

for i in range(self.num_columns):
new_name = idx_to_new_name.get(i, self.column_names[i])
c_names.push_back(tobytes(new_name))
else:
raise TypeError(f"names must be a list or dict not {type(names)!r}")

with nogil:
c_batch = GetResultValue(self.batch.RenameColumns(move(c_names)))
Expand Down Expand Up @@ -5122,8 +5156,17 @@ cdef class Table(_Tabular):

Parameters
----------
names : list of str
List of new column names.
names : list[str] or dict[str, str]
List of new column names or mapping of old column names to new column names.

If a mapping of old to new column names is passed, then all columns which are
found to match a provided old column name will be renamed to the new column name.
If any column names are not found in the mapping, a KeyError will be raised.

Raises
------
KeyError
If any of the column names passed in the names mapping do not exist.

Returns
-------
Expand All @@ -5144,13 +5187,37 @@ cdef class Table(_Tabular):
----
n: [[2,4,5,100]]
name: [["Flamingo","Horse","Brittle stars","Centipede"]]
>>> new_names = {"n_legs": "n", "animals": "name"}
>>> table.rename_columns(new_names)
pyarrow.Table
n: int64
name: string
----
n: [[2,4,5,100]]
name: [["Flamingo","Horse","Brittle stars","Centipede"]]
"""
cdef:
shared_ptr[CTable] c_table
vector[c_string] c_names

for name in names:
c_names.push_back(tobytes(name))
if isinstance(names, list):
for name in names:
c_names.push_back(tobytes(name))
elif isinstance(names, dict):
idx_to_new_name = {}
for name, new_name in names.items():
indices = self.schema.get_all_field_indices(name)

if not indices:
raise KeyError("Column {!r} not found".format(name))

for index in indices:
idx_to_new_name[index] = new_name

for i in range(self.num_columns):
c_names.push_back(tobytes(idx_to_new_name.get(i, self.schema[i].name)))
else:
raise TypeError(f"names must be a list or dict not {type(names)!r}")

with nogil:
c_table = GetResultValue(self.table.RenameColumns(move(c_names)))
Expand Down
37 changes: 37 additions & 0 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,43 @@ def test_table_rename_columns(cls):
expected = cls.from_arrays(data, names=['eh', 'bee', 'sea'])
assert t2.equals(expected)

message = "names must be a list or dict not <class 'str'>"
with pytest.raises(TypeError, match=message):
table.rename_columns('not a list')


@pytest.mark.parametrize(
('cls'),
[
(pa.Table),
(pa.RecordBatch)
]
)
def test_table_rename_columns_mapping(cls):
data = [
pa.array(range(5)),
pa.array([-10, -5, 0, 5, 10]),
pa.array(range(5, 10))
]
table = cls.from_arrays(data, names=['a', 'b', 'c'])
assert table.column_names == ['a', 'b', 'c']

expected = cls.from_arrays(data, names=['eh', 'b', 'sea'])
t1 = table.rename_columns({'a': 'eh', 'c': 'sea'})
t1.validate()
assert t1 == expected

# Test renaming duplicate column names
table = cls.from_arrays(data, names=['a', 'a', 'c'])
expected = cls.from_arrays(data, names=['eh', 'eh', 'sea'])
t2 = table.rename_columns({'a': 'eh', 'c': 'sea'})
t2.validate()
assert t2 == expected

# Test column not found
with pytest.raises(KeyError, match=r"Column 'd' not found"):
table.rename_columns({'a': 'eh', 'd': 'sea'})


def test_table_flatten():
ty1 = pa.struct([pa.field('x', pa.int16()),
Expand Down