Generalized outer product #729
Closed
Labels: pr-welcome (We are open to PRs that fix this issue - leave a note if you're working on it!)
Description
Hello and thank you for creating more-itertools! It's great.
I'd like to propose adding a generalized outer product function. It's similar to the itertools.product function, except that it preserves the nested 2D structure of the result (one row per item of the first iterable) and lets you apply a function to every pair of items. This function is surprisingly versatile!
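To make the comparison with itertools.product concrete, here is a small sketch that repeats the outer_product definition proposed in this issue so it runs on its own. It applies mul to the flat stream of pairs from product and then regroups that stream into rows of len(ys), which recovers exactly the matrix that outer_product builds directly.

```python
from itertools import product, starmap
from operator import mul
from typing import Callable, Iterable, TypeVar

T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V')


def outer_product(func: Callable[[T, U], V], xs: Iterable[T], ys: Iterable[U]) -> Iterable[list[V]]:
    ys = list(ys)  # Consume ys once so it can be reused for every row
    return ([func(x, y) for y in ys] for x in xs)


xs = [1, 2, 3]
ys = [10, 20]

# itertools.product yields a flat stream of pairs, so applying func gives a flat list.
flat = list(starmap(mul, product(xs, ys)))
# flat == [10, 20, 20, 40, 30, 60]

# outer_product keeps the 2D shape: one row per element of xs.
nested = list(outer_product(mul, xs, ys))
# nested == [[10, 20], [20, 40], [30, 60]]

# Chopping the flat stream into rows of len(ys) recovers the same matrix.
width = len(ys)
regrouped = [flat[i:i + width] for i in range(0, len(flat), width)]
assert regrouped == nested
```

In other words, outer_product saves you from tracking the row width and re-chunking by hand.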
References
This function exists mainly in array languages. It also exists in NumPy (as the outer method of ufuncs), but it's cumbersome to use because you first need to create a ufunc (for example with numpy.frompyfunc). Ufuncs also don't play well with mypy.
Examples
from typing import Callable, Iterable, TypeVar

T = TypeVar('T')
U = TypeVar('U')
V = TypeVar('V')


def outer_product(func: Callable[[T, U], V], xs: Iterable[T], ys: Iterable[U]) -> Iterable[list[V]]:
    ys = list(ys)  # Consume ys once
    return ([func(x, y) for y in ys] for x in xs)
if __name__ == '__main__':
    from itertools import compress
    from operator import mul

    from more_itertools import dotproduct, matmul, transpose

    # Multiplication table
    assert list(outer_product(mul, [1, 2, 3], [1, 2, 3])) == [[1, 2, 3],
                                                              [2, 4, 6],
                                                              [3, 6, 9]]

    # Matrix multiplication
    # Source: https://futhark-lang.org/examples/outer-product.html
    # You have to admit that outer_product(dotproduct, A, transpose(B)) is a pretty neat way to write matmul!
    A = [[1, 2], [3, 4], [5, 6]]
    B = [[7, 8, 9], [10, 11, 12]]
    # matmul() yields each row as a tuple, so convert the rows to lists before comparing.
    assert list(outer_product(dotproduct, A, transpose(B))) == [list(row) for row in matmul(A, B)]
    # [[27, 30, 33],
    #  [61, 68, 75],
    #  [95, 106, 117]]

    # Greetings
    greetings = ['Hello', 'Goodbye']
    names = ['Alice', 'Bob']
    assert list(outer_product(lambda g, n: f'{g}, {n}!', greetings, names)) == [
        ['Hello, Alice!', 'Hello, Bob!'],
        ['Goodbye, Alice!', 'Goodbye, Bob!'],
    ]

    # Distance matrix
    coords = [[0.68446066, 0.09628254],
              [0.05818767, 0.26179779],
              [0.76983281, 0.65376925],
              [0.6214879, 0.7856759]]

    def distance(p1, p2):
        # Euclidean distance between two 2D points, rounded to 2 decimal places
        return round(((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5, 2)

    assert list(outer_product(distance, coords, coords)) == [[0.0, 0.65, 0.56, 0.69],
                                                             [0.65, 0.0, 0.81, 0.77],
                                                             [0.56, 0.81, 0.0, 0.2],
                                                             [0.69, 0.77, 0.2, 0.0]]

    # Filter out sentences that contain certain words
    # I've actually used something like this in production code!
    sentences = ["The quick brown fox jumps over the lazy dog.",
                 "The truth will set you free.",
                 "The only thing we have to fear is fear itself."]
    words = ['quick', 'fear']
    # One row per sentence: True where that word does NOT occur in the sentence
    mask = outer_product(lambda s, w: w not in s, sentences, words)
    # Keep only the sentences whose row is all True, i.e. that contain none of the words
    filtered_sentences = list(compress(sentences, map(all, mask)))
    assert filtered_sentences == ['The truth will set you free.']

This function can also be defined for higher dimensions.
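A higher-dimensional version could look something like the sketch below. The name nested_outer_product is hypothetical (it is not an existing more-itertools function). Unlike the two-argument outer_product, which yields rows lazily, this sketch materializes every input and builds the whole nested list eagerly, because each inner axis has to be revisited once per item of the outer axes.

```python
from operator import mul
from typing import Any, Callable, Iterable


def nested_outer_product(func: Callable[..., Any], *iterables: Iterable[Any]) -> list:
    """Hypothetical n-ary outer product: one nested list level per input iterable.

    With two inputs the result matches list(outer_product(func, xs, ys)).
    """
    # Materialize every input once so the inner axes can be reused for each outer item.
    axes = [list(it) for it in iterables]
    if not axes:
        raise TypeError('nested_outer_product() needs at least one iterable')

    def build(depth: int, args: tuple) -> Any:
        # All axes fixed: call func with one item taken from each input.
        if depth == len(axes):
            return func(*args)
        # Otherwise produce one nested list entry per item on the current axis.
        return [build(depth + 1, args + (item,)) for item in axes[depth]]

    return build(0, ())


# Two inputs reproduce the multiplication table from the examples above.
table = nested_outer_product(mul, [1, 2, 3], [1, 2, 3])
# table == [[1, 2, 3], [2, 4, 6], [3, 6, 9]]

# Three inputs give a 2 x 2 x 2 "cube".
cube = nested_outer_product(lambda a, b, c: a + b + c, [0, 100], [0, 10], [1, 2])
# cube == [[[1, 2], [11, 12]], [[101, 102], [111, 112]]]
```

Each extra input adds one level of nesting, so for three inputs cube[i][j][k] == func(xs[i], ys[j], zs[k]).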
Thank you for your consideration!