[tx] Support top_k sampling#680
Conversation
|
@gemini-code-assist review |
There was a problem hiding this comment.
Code Review
This pull request successfully implements top_k sampling, including the core logic and API integration. The changes are well-tested. I've provided a few suggestions to enhance performance and improve the readability and robustness of the tests.
There was a problem hiding this comment.
Code Review
This pull request implements top_k sampling. The changes look good and include necessary logic in the generation pipeline and new tests. I've provided a few suggestions to improve code maintainability and performance. Specifically, I've suggested refactoring duplicated test code, simplifying assertions, and optimizing the apply_top_k function for better performance.
| # Values below threshold should be -inf | ||
| assert jnp.isinf(filtered[0]) and filtered[0] < 0 | ||
| assert jnp.isinf(filtered[1]) and filtered[1] < 0 | ||
| assert jnp.isinf(filtered[2]) and filtered[2] < 0 | ||
| # Top 2 values should be unchanged | ||
| assert filtered[3] == 4.0 | ||
| assert filtered[4] == 5.0 |
There was a problem hiding this comment.
The assertions to check the filtered logits can be simplified by comparing against an expected array. This makes the test more concise and easier to read.
| # Values below threshold should be -inf | |
| assert jnp.isinf(filtered[0]) and filtered[0] < 0 | |
| assert jnp.isinf(filtered[1]) and filtered[1] < 0 | |
| assert jnp.isinf(filtered[2]) and filtered[2] < 0 | |
| # Top 2 values should be unchanged | |
| assert filtered[3] == 4.0 | |
| assert filtered[4] == 5.0 | |
| expected = jnp.array([-jnp.inf, -jnp.inf, -jnp.inf, 4.0, 5.0]) | |
| assert jnp.array_equal(filtered, expected) |
9bf5c08 to
bbc5056
Compare
|
@tyler-griggs lmk if you have any comments for top_k Rebased to most recent main |
bbc5056 to
1766982
Compare
|
@pcmoritz @tyler-griggs lmk if you have thoughts, happy to rebase again to most recent before you merge |
| # Values below threshold should be -inf | ||
| assert jnp.isinf(filtered[0]) and filtered[0] < 0 | ||
| assert jnp.isinf(filtered[1]) and filtered[1] < 0 | ||
| assert jnp.isinf(filtered[2]) and filtered[2] < 0 | ||
| # Top 2 values should be unchanged | ||
| assert filtered[3] == 4.0 | ||
| assert filtered[4] == 5.0 |
Resolve merge conflicts to combine: - top_k sampling support from feature branch - stop_strings support from main branch Both features are now available in SamplingParams. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Tyler Griggs <131809874+tyler-griggs@users.noreply.github.com>
…nto feature/implement-top-k
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request successfully implements top_k sampling. The changes are well-structured, touching the API, type definitions, and the core generator logic. The implementation of apply_top_k_batch is efficient and JIT-friendly. The new tests in test_api.py and test_generator.py provide good coverage for the new functionality.
I have a few suggestions to improve the assertions in the tests to make them more robust and concise. Overall, this is a solid contribution.
pcmoritz
left a comment
There was a problem hiding this comment.
Thanks a lot for implementing this @agolajko , I implemented @tyler-griggs 's suggestion of a fast path if there is no top_k filtering, and also used jax.lax.top_k so we don't need to do the sorting :)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
|
I'll merge this now, but there are some improvements that could be made as a follow up:
|
Implements the
top_kpart of the sampling API requested in #533Tests
test_top_k_filteringintest_generator.py: tests the core logic of the samplingtest_sample_top_kintest_api.py: checks the API can be called with thetop_kparameterDiscussed with @pcmoritz on slack