Speed up zip_broadcast #737

Closed
Masynchin wants to merge 14 commits into more-itertools:master from Masynchin:zip_broadcast

Conversation

@Masynchin
Contributor

@Masynchin Masynchin commented Jul 23, 2023

Issue reference

No Issues 😎

Changes

Changed the logic inside zip_broadcast: it now converts scalars to iterables, so that it can just zip them all together.

Simple benchmark:

  • Before:
Result: 37.9 ms
  • After:
Result: 3.21 ms
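The core idea can be sketched roughly like this (a hypothetical simplification of the non-strict case, not the actual PR code; the real function also handles the all-scalars case, which this sketch does not):

```python
from itertools import repeat

def zip_broadcast_sketch(*objects, scalar_types=(str, bytes)):
    # Treat non-iterables (and str/bytes by default) as scalars.
    def is_scalar(obj):
        if scalar_types and isinstance(obj, scalar_types):
            return True
        try:
            iter(obj)
        except TypeError:
            return True
        return False

    # Wrap each scalar in repeat() so a single zip() can drive iteration;
    # zip() stops as soon as any real iterable is exhausted.
    iterables = [repeat(obj) if is_scalar(obj) else obj for obj in objects]
    return zip(*iterables)

print(list(zip_broadcast_sketch(1, [2, 3], 'ab')))
# [(1, 2, 'ab'), (1, 3, 'ab')]
```

This is what makes the speed-up possible: the inner loop is a plain C-level zip() instead of per-element Python logic.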

@Masynchin
Contributor Author

It is almost exactly the same as the original implementation (0926631), but passes the new test cases introduced after #561 and #565.

@kalekundert
Contributor

I don't think this implementation works in the following case:

>>> from more_itertools import zip_broadcast
>>> list(zip_broadcast(1, [1,2], strict=True))
[(1, 1), (1, 2)]

The problem is that repeat(1) will not be considered same size as [1,2], so the strict=True check will fail. It should succeed though, because the scalars are meant to be treated as iterables of the same length as their peers. See #543 for more discussion of this issue.

Unfortunately, it seems like this case is not covered by the tests. I feel bad about that; this is definitely a case that I should've included in #565. I realize now that there are no tests for cases in which zip_broadcast(..., strict=True) does anything except raise an exception, so we should definitely add some.

Maybe a way to improve the performance of zip_broadcast() would be to make custom wrappers for the scalars and the iterables. The iterable wrappers would detect when the iterables become exhausted, then notify the scalar wrappers. The scalar wrappers would behave like repeat() initially, but would stop upon receiving the aforementioned notification. Not totally sure if this would avoid the issues from #561, but it might be worth thinking about.

@bbayles
Collaborator

bbayles commented Jul 25, 2023

I'll take a PR with more tests, please!

@Masynchin
Contributor Author

@kalekundert I see your optimization in #740. I was thinking about a different approach and came up with this idea (but no drafts or PoC yet).

zip_broadcast uses _zip_equal if strict:

zipper = _zip_equal if strict else zip

which falls back to _zip_equal_generator when any of the passed iterables has no length, which in this PR's version is the case for the "scalar" arguments, since they are in fact repeat(scalar):

# If any one of the iterables didn't have a length, start reading
# them until one runs out.
except TypeError:
    return _zip_equal_generator(iterables)

So the issue is that repeat never stops: when any "real" iterable is exhausted, zip_longest yields a combo of the scalar and _marker. Because of that, an UnequalIterablesError is raised, which we don't want.

def _zip_equal_generator(iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        for val in combo:
            if val is _marker:
                raise UnequalIterablesError()
        yield combo

What if we write another version of _zip_equal and _zip_equal_generator specifically for zip_broadcast? I'm thinking of something like this pseudocode:

def _zip_equal_generator(n, *iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        match combo.count(_marker):
            case 0: yield combo
            case n: break
            case _: raise UnequalIterablesError()

Where n is the count of "real" iterables. I have added a check on the marker count: if it equals n, that means all real iterables stopped at the same length, which handles the case you described in your first reply.

What do you think? (Sorry for my bad English)

@kalekundert
Contributor

That seems like a good idea to me. I can't think of any reason why it wouldn't work, and it should be just as fast as your original implementation.

Also, just to be clear, I didn't mean for #740 to preclude more substantial optimizations like this. I just noticed a small inefficiency, and it was easy to fix.

@Masynchin
Contributor Author

I didn't mean for #740 to preclude more substantial optimizations like this. I just noticed a small inefficiency, and it was easy to fix.

Your optimization is absolutely fine; I didn't mean any offense by it. Hope the other ideas work out too!

@bbayles
Collaborator

bbayles commented Jul 25, 2023

For both branches, could you please make sure that you've merged in master and gotten the new tests?

@Masynchin
Contributor Author

For both branches, could you please make sure that you've merged in master and gotten the new tests?

Yep, done!

@Masynchin
Contributor Author

@kalekundert I have implemented the optimization above; it seems to pass both the old tests and the new ones you added in #739. I would appreciate your help reviewing it.

@Masynchin
Contributor Author

I just copy-pasted the _zip_equal/_zip_equal_generator implementations, slightly rewrote them, and added a _scalars suffix. I haven't checked whether we can optimize them specifically for scalars; I only made sure the tests pass.

@kalekundert
Contributor

I only have a few very minor comments. Most of them pertain more to _zip_equal() than to your PR, but I included them anyways just because they were things I noticed while reading the code. Feel free to ignore:

  • I think it's probably worth avoiding the duplication of _zip_equal(), e.g. by giving the existing _zip_equal() a keyword-only zip_equal_generator argument that defaults to the standard behavior, but can be changed in cases like this.
  • Alternatively, it might make sense to not use _zip_equal() at all. Most of the time there will be at least one scalar argument, so _zip_equal() will just end up calling _zip_equal_generator() anyways. The extra complexity/code duplication doesn't really seem worth it. This code path also involves raising and catching an exception, which is a relatively expensive operation, but I don't really think the difference would be noticeable here.
  • I'd be tempted to define _marker as a local variable, both in your version of _zip_equal_generator() and the existing version. Not only would this completely eliminate the possibility of this value appearing in a user-provided iterable (which is admittedly a paranoid concern), it would also be ≈5% faster for iterables with more than ≈100 elements due to local variable lookups being more efficient than global ones (see below).
  • The for val in combo: loop in _zip_equal_generator() could be replaced with if _marker in combo:. That would be easier to understand, and ≈10% faster (see below).
  • The for/else logic in _zip_equal() seems unnecessarily complex. Why not just raise the exception immediately, instead of breaking and raising it outside the loop? Of course, this has nothing to do with this PR.

Other than that, everything looks good to me. I assume this version still runs ≈10x faster than the current implementation?


Speed tests for minor optimizations mentioned above:

from itertools import zip_longest
from more_itertools import consume

_marker = object()

# Stub so the snippet runs standalone (never raised here: the
# iterables all have equal lengths).
class UnequalIterablesError(Exception):
    pass

args = [
    range(100_000),
    range(100_000),
    range(100_000),
]

def z1(iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        for val in combo:
            if val is _marker:
                raise UnequalIterablesError()
        yield combo

def z2(iterables):
    _marker = object()
    for combo in zip_longest(*iterables, fillvalue=_marker):
        for val in combo:
            if val is _marker:
                raise UnequalIterablesError()
        yield combo

def z3(iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        if _marker in combo:
            raise UnequalIterablesError()
        yield combo

def z4(iterables):
    _marker = object()
    for combo in zip_longest(*iterables, fillvalue=_marker):
        if _marker in combo:
            raise UnequalIterablesError()
        yield combo

import timeit

for f in [z1, z2, z3, z4]:
    print(
        timeit.timeit(
            stmt='consume(f(args))',
            globals=dict(f=f) | globals(),
            number=100,
        )
    )

Output:

2.199860891967546
2.0907241030363366
2.0204972539795563
1.9232789399684407

@Masynchin
Contributor Author

Masynchin commented Jul 26, 2023

  • I think it's probably worth avoiding the duplication of _zip_equal(), e.g. by giving the existing _zip_equal() a keyword-only zip_equal_generator argument that defaults to the standard behavior, but can be changed in cases like this.

I can do something like this:

def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
    ...
    if strict:
+        yield from _zip_equal(
+            *iterables, gen=partial(
+                _zip_equal_generator_scalars, iterables_count
+            )
+        )
    else:
        yield from zip(*iterables)


+def _zip_equal(*iterables, gen=_zip_equal_generator):
    ...
    except TypeError:
+        return gen(iterables)


def _zip_equal_generator_scalars(n, iterables):
    ...

It reverts _zip_equal to its previous positional arguments (no need to provide n), with a new gen keyword. Does the partial part look good? Would it be better with a lambda, or should we rewrite this and choose a different approach?

  • This code path also involves raising and catching an exception, which is a relatively expensive operation, but I don't really think the difference would be noticeable here.

It is also raised when a regular iterable of infinite/undefined length occurs, so how can this check be eliminated?

@kalekundert
Contributor

kalekundert commented Jul 26, 2023

  • Yeah, that's exactly what I had in mind. And I think partial() is the right way to go.
  • What I meant was something like this:
    def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
        ...
        if strict:
    +        # It would also make sense to just put the for loop directly here, with no extra function call.
    +        yield from _zip_equal_generator_scalars(iterables, iterables_count)
        else:
            yield from zip(*iterables)
    Basically, just don't call _zip_equal() at all. The only downside is that you end up using zip_longest() instead of zip() in the case where all the iterables have lengths that are the same. But this will not often be the case.

Don't take my suggestions too seriously; I'm not sure they're all good ideas.

@Masynchin
Contributor Author

  • I'd be tempted to define _marker as a local variable, both in your version of _zip_equal_generator() and the existing version. Not only would this completely eliminate the possibility of this value appearing in a user-provided iterable (which is admittedly a paranoid concern), it would also be ≈5% faster for iterables with more than ≈100 elements due to local variable lookups being more efficient than global ones (see below).
  • The for val in combo: loop in _zip_equal_generator() could be replaced with if _marker in combo:. That would be easier to understand, and ≈10% faster (see below).
  • The for/else logic in _zip_equal() seems unnecessarily complex. Why not just raise the exception immediately, instead of breaking and raising it outside the loop? Of course, this has nothing to do with this PR.

Good spots! Maybe I shouldn't optimize _zip_equal_generator in this PR, so that you can add that optimization and all the others in an optimization-specific PR.

@Masynchin
Contributor Author

Basically, just don't call _zip_equal() at all. The only downside is that you end up using zip_longest() instead of zip() in the case where all the iterables have lengths that are the same. But this will not often be the case.

I'm not sure they're all good ideas.

Feel free to try; I invited you to collaborate on my fork, so you can push your commits to this PR directly. I'll go along with whatever you decide about adding or not adding a change such as the above.

@Masynchin
Contributor Author

@bbayles I have resolved the conflicts; is this planned to be merged?

@pochmann
Contributor

I think about something like this pseudocode:

def _zip_equal_generator(n, *iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        match combo.count(_marker):
            case 0: yield combo
            case n: break
            case _: raise UnequalIterablesError()

That's not safe: it miscounts if there's a non-marker object that compares equal to the marker. Demo:

from unittest.mock import ANY

_marker = object()
for combo in zip('foo', [1, ANY, 3]):
    print(combo.count(_marker))

Output:

0
1
0

@Masynchin
Contributor Author

@pochmann if it treats ANY as _marker, shouldn't it also fail in the current implementation? I can't test this right now.

@Masynchin
Contributor Author

@pochmann if it treats ANY as _marker, shouldn't it also fail in the current implementation? I can't test this right now.

Oh, is it because the current one uses an `is` comparison while this PR uses count, which uses __eq__ under the hood? If so, can it be fixed with something like sum(1 for o in combo if o is _marker)?

@pochmann
Contributor

pochmann commented Jul 30, 2023

Yes, counting with is would be correct.

Btw I think I also optimized this, but can't find it right now... Maybe I dismissed it because the current implementation is simpler (especially after prefilling the scalars, which I had also done). But mine might be simpler than the new suggestion. I might try again...

@pochmann
Contributor

2.199860891967546
2.0907241030363366
2.0204972539795563
1.9232789399684407

What Python version did you use? I think globals got faster in the last few versions. And are those results stable? (I.e., you ran it multiple times and always got very similar results?)

@pochmann
Contributor

pochmann commented Jul 30, 2023

@Masynchin About your current proposal: I'd rather do it like this, without the extra functions:

def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
    ...
    iterables = [repeat(obj) if is_scalar(obj) else obj for obj in objects]

    if not strict:
        yield from zip(*iterables)
        return

    if lengths of the non-scalars are all the same:
        yield from zip(*iterables)
        return
    (or raise UnequalIterablesError if that's the case)

    for combo in zip_longest(*iterables, fillvalue=_marker):
        ...

Advantages:

  • Less code.
  • Higher chance of the are-all-lengths-equal check succeeding, as it doesn't include the repeated scalars in that check.
  • Faster last case (the for combo in case) as it avoids one generator layer. (The other cases could also avoid their generator layer, by making zip_broadcast not be a generator.)
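For concreteness, the strict path discussed in this thread might look something like the following (a hypothetical, simplified sketch, not the PR's actual code; it uses identity-based marker counting, treats a marker count equal to the number of real iterables as a clean stop, and ignores the all-scalars edge case):

```python
from itertools import repeat, zip_longest

_marker = object()

class UnequalIterablesError(ValueError):
    pass

def zip_broadcast_strict_sketch(*objects, scalar_types=(str, bytes)):
    # Illustrative names; the real function and its helpers differ.
    def is_scalar(obj):
        if scalar_types and isinstance(obj, scalar_types):
            return True
        try:
            iter(obj)
        except TypeError:
            return True
        return False

    n_real = sum(1 for obj in objects if not is_scalar(obj))
    iterables = [repeat(obj) if is_scalar(obj) else obj for obj in objects]

    for combo in zip_longest(*iterables, fillvalue=_marker):
        # Count sentinels by identity, not equality.
        markers = sum(1 for v in combo if v is _marker)
        if markers == 0:
            yield combo
        elif markers == n_real:
            # All real iterables ended together; only repeats remain.
            return
        else:
            raise UnequalIterablesError()
```

With this structure, `list(zip_broadcast_strict_sketch(1, [1, 2]))` broadcasts the scalar correctly, while unequal real iterables still raise.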

@Masynchin
Contributor Author

That's not safe. Miscounts if there's a non-marker object that equals the marker

Fixed and added as a test

@Masynchin
Contributor Author

Also, what should I do about flake8 in CI? Same as #742 (comment)

@bbayles
Collaborator

bbayles commented Jul 31, 2023

Merge in the master branch for the flake8 issue, if you could please.

@Masynchin
Contributor Author

I found that in this PR UnequalIterablesError may be raised with an incorrect iterable index. Let me fix that before the merge

@kalekundert
Contributor

2.199860891967546
2.0907241030363366
2.0204972539795563
1.9232789399684407

What Python version did you use? I think globals got faster in the last few versions. And are those results stable? (I.e., you ran it multiple times and always got very similar results?)

I used version 3.10.0, and I ran it a bunch of times. IIRC, it's not a stable effect when there are only ≈10 items per iterable, but it's very stable by the time you get to ≈1000. That said, this is a really micro optimization, and I just pointed it out because I noticed it.

@Masynchin
Contributor Author

I found that in this PR UnequalIterablesError may be raised with an incorrect iterable index. Let me fix that before the merge

Fixed

I'd rather do it like this, without the extra functions

Done

@pochmann
Contributor

pochmann commented Aug 1, 2023

I found that in this PR UnequalIterablesError may be raised with an incorrect iterable index

Can you tell what that was, give an example? I don't see it, and the fix commit changes a lot of code.


-iterables_count = sum(1 for obj in objects if not is_scalar(obj))
+iterables = list(filterfalse(is_scalar, objects))
+iterables_count = ilen(iterables)
Contributor


It's a list, just use len. And I don't think you need the count in a variable: you only use it once or twice. The first time you can just check the list instead, and the second time you can just call len there.

Contributor Author


It's a list, just use len.

My bad, I had filterfalse in mind while adding ilen.

...and the second time you can just call len there.

elif markers == iterables_count:

Should I use len here? It would be more performant to check the length once rather than on every step of the for-loop.

Contributor


Should I use len here?

Yes.

It would be more performant if I check the length once and not every step of for-loop

No. You get there at most once.

Contributor Author


You get there at most once.

Am I missing something?

Left side: code with a debug print where the length evaluation would happen; right side: a run of the proposed code with 10 debug prints

Contributor

@pochmann pochmann Aug 1, 2023


Your print is in the wrong place: you only get to the len call if the if condition is false. Prepend the print to the elif instead: elif print(...) or condition:.

Contributor Author


You are right: if the if condition is false, the for-loop terminates in either the elif or the else branch. Thanks for catching that!

@Masynchin
Contributor Author

I found that in this PR UnequalIterableError may be raised with incorrect iterable index

Can you tell what that was, give an example?

Consider consume(zip_broadcast(0, [1], [1, 2], strict=True)). There is one scalar and two real iterables. The current implementation checks the lengths of these iterables, ignoring the indexes of the scalars. This PR now does the same, and raises:

UnequalIterablesError: Iterables have different lengths: index 0 has length 1; index 1 has length 2

Before the fix it would be index 1 has length 1; index 2 has length 2. I think it was a bug in my implementation, but if you count it as a feature, I can revert it.

@pochmann
Contributor

pochmann commented Aug 1, 2023

Before the fix it would be index 1 has length 1; index 2 has length 2.

Ah, ok. Not sure which is better, indexes referring to all arguments or just to iterables.

Wasn't it index 0 has length 1; index 2 has length 2, though? (The index 0 is hardcoded.) That would really be wrong.

@Masynchin
Contributor Author

Wasn't it index 0 has length 1; index 2 has length 2, though?

I just reran consume(zip_broadcast(0, [1], [1, 2], strict=True)) on the commit before the fix, and it raises UnequalIterablesError: Iterables have different lengths without any length indexes 😬

@Masynchin
Contributor Author

I ran this benchmark on Python 3.11.3:

import timeit
from more_itertools import consume, zip_broadcast

N = 1_000
G = globals()

t1 = timeit.timeit("consume(zip_broadcast(1, 2, [1] * 100_000))", number=N, globals=G)
t2 = timeit.timeit("consume(zip_broadcast(1, 2, [1] * 200_000, [2, 3] * 100_000))", number=N, globals=G)
t3 = timeit.timeit("consume(zip_broadcast(1, 2, [1] * 100_000, strict=True))", number=N, globals=G)

print(t1, t2, t3, sep="\n")

On master (266ebdc) and this PR (2dd6fe2), here are the results:

  • Master
17.083783166483045
40.894872582517564
21.31543183233589
  • PR
2.304440625011921
5.877900958992541
2.28413129132241

Almost a 10x speed-up. There are no unresolved questions; can we merge it?

@pochmann
Contributor

pochmann commented Aug 5, 2023

I think none of those tests your slow case; can you test that as well?

@bbayles
Collaborator

bbayles commented Aug 31, 2023

Is there still a test to add here?

@Masynchin
Contributor Author

Is there still a test to add here?

After #739 was merged and I tweaked this PR to pass its tests, no other problems were found. The only thing requested is regression benchmarks, but I am currently too busy with other stuff. My thought is that it can only regress if the caller provides only one iterable and N scalars. I would be happy if anyone could verify this.

@bbayles
Collaborator

bbayles commented Jan 5, 2026

I think nobody is eager to finish this off, so closing. Thanks for the contribution nonetheless!

@bbayles bbayles closed this Jan 5, 2026