from functools import reduce, partial
from itertools import batched, chain
from more_itertools import collapse
def reshape(matrix, shape):
"""Change the shape a *matrix*.
If *shape* is an integer, the matrix must be two dimensional
and the shape is interpreted as the desired number of columns:
>>> matrix = [(0, 1), (2, 3), (4, 5)]
>>> cols = 3
>>> list(reshape(matrix, cols))
[(0, 1, 2), (3, 4, 5)]
If *shape* is a tuple, the input matrix can have any number
of dimensions. It will first be flattened and then rebuilt
to the desired shape which can also be multidimensional.
>>> matrix = [(0, 1), (2, 3), (4, 5)]
>>> list(reshape(matrix, (2, 3))) # Make a 2 x 3 matrix
[(0, 1, 2), (3, 4, 5)]
>>> list(reshape(matrix, (6,))) # Make a vector of length six
[0, 1, 2, 3, 4, 5]
>>> list(reshape(matrix, (2, 1, 3, 1))) # Make 2 x 1 x 3 x 1 tensor
[(((0,), (1,), (2,)),), (((3,), (4,), (5,)),)]
"""
if isinstance(shape, int):
return batched(chain.from_iterable(matrix), shape)
_batched = partial(batched, strict=True)
flat = collapse(matrix)
return next(reduce(_batched, reversed(shape), initial=flat))