Adding FlaxNoRepeatNGramLogitsProcessor #29677
ArthurZucker merged 14 commits into huggingface:main
Conversation
gante
left a comment
Thank you for opening the PR 🔥 In general looks good to me, I've added a few questions before approving.
And thank you for keeping the exact same tests as in our PT counterpart, it makes maintenance much simpler 🙌
```python
data = jnp.ones((all_update_indices.shape[0],), dtype=jnp.uint16)
data = data * (jnp.arange(data.shape[0]) < batch_size * (cur_len - (self.ngram_size - 1)))  # ignore the n-grams not yet generated
```
Perhaps we could slice `input_ids` before creating `all_update_indices`, i.e. `input_ids = input_ids[:, :cur_len]`, and save some time/memory when creating `all_update_indices`.
Or is the result slower, because `cur_len` changes each iteration?
From my experience with JAX, for code to work with jit you cannot use arrays whose shapes are not fixed. This is why I opted to pad the indices to a known size. There may be more efficient ways to do it, but I found that this works, as opposed to `[]` slices or even dynamic slices, because `cur_len` changes.
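A minimal sketch of the fixed-shape workaround described above (the names here are illustrative, not the PR's code): under jit, `cur_len` is a traced value, so `input_ids[:, :cur_len]` would raise; instead the full array is kept and the not-yet-generated positions are masked out.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_row_sums(input_ids, cur_len):
    # Boolean mask over the static seq_len axis keeps every shape fixed,
    # even though cur_len itself is a traced scalar inside jit.
    mask = jnp.arange(input_ids.shape[1]) < cur_len
    return jnp.sum(input_ids * mask, axis=1)

ids = jnp.array([[1, 2, 3, 4], [5, 6, 7, 8]])
print(masked_row_sums(ids, 2).tolist())  # [3, 11]
```

Replacing the mask with `input_ids[:, :cur_len]` inside the jitted function reproduces the static-slice error discussed below.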
I see :) I made the suggestion based on my past XLA+TF experience -- slicing `input_ids = input_ids[:, :cur_len]` is allowed there (example).
But then again, JAX is usually stricter. Let's keep it as you suggested 🤗
I think it also works in JAX in some situations, but in this case the error when using jit is quite explicit:

```
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
```

Anyway, I took some time to benchmark different ways of doing this kind of operation, and found that in this instance using `jax.lax.fori_loop` to update `all_update_indices` is significantly faster than using dynamic-slice updates (with this plus the removal of some useless operations, I measured a >10x speedup for the jitted function). There might still be room for improvement.
The red CI can be fixed by running
amyeroberts
left a comment
Thanks for adding!
I have just one request to update a test. I realise this is inherited, but it really should be addressed, as the test is very confusing.
```python
return val.at[i].set(
    jnp.array(
        [
            b,
        ]
        + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
    )
)
```
nit - can all be one line
Suggested change:

```python
return val.at[i].set(jnp.array([b] + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]))
```
I think this was formatted like this by the black/ruff formatter. But if the suggestion passes the checks, I agree it is clearer this way.
```python
shape = (batch_size * (seq_len - (self.ngram_size - 1)), self.ngram_size + 1)
all_update_indices = jax.lax.fori_loop(
    0, batch_size * (cur_len - (self.ngram_size - 1)), body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
)

# ignore the n-grams not yet generated
data = (
    jnp.arange(batch_size * (seq_len - (self.ngram_size - 1))) < batch_size * (cur_len - (self.ngram_size - 1))
).astype("float32")

return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)
```
Rather than calculating `seq_len - (self.ngram_size - 1)` and `cur_len - (self.ngram_size - 1)` several times, this code will be easier to read and follow if they're assigned to variables and then used.
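A hypothetical refactor along the lines of this suggestion (the names below are illustrative, not from the PR): compute the two repeated window counts once and reuse them.

```python
def ngram_window_counts(batch_size, seq_len, cur_len, ngram_size):
    # Total n-gram slots in the padded sequence vs. slots generated so far;
    # naming these once replaces the repeated inline arithmetic.
    total_ngrams = batch_size * (seq_len - (ngram_size - 1))
    generated_ngrams = batch_size * (cur_len - (ngram_size - 1))
    return total_ngrams, generated_ngrams

print(ngram_window_counts(2, 5, 4, 2))  # (8, 6)
```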
```python
# 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]])

# 3-gram would forbid no token at 1st batch and 1st token (0) at 2nd batch
self.assertListEqual(jnp.isinf(filtered_scores_3_gram).tolist(), [[False, False, False], [True, False, False]])
```
I realise these are copied from other parts of the library, but the structure here is really quite confusing.
I'm assuming that:
- By "batch" we mean a sample in the minibatch, i.e. the "1st batch" is `[1, 1, 2, 1]`
- By "tokens" we mean token ids
- The output values, e.g. `[False, True, True]`, are across the vocab
- When token ids are referred to, e.g. "2nd and 3rd token at 1st batch", what we're referring to are the token ids in the vocabulary `[0, 1, 2]` and NOT the 2nd and 3rd token ids in `[1, 1, 2, 1]`

The comment makes this really confusing by 1) having the same positional values for the sample in the batch as in the vocab and 2) saying "at 1st batch". I'd strongly recommend rewriting this to remove the ambiguity.
The tests were copied with minimal modifications from the PyTorch ones, so I didn't change these comments. I agree that the descriptions could be clearer.
I also think more comprehensive tests could be a good idea. For example, I didn't see at first that my first code iteration didn't work when an n-gram appears more than once; that case is not tested, so my code passed the tests.
Good point. As these are so close to the PT & TF tests, let's leave them as-is for now. If you have a test in mind and are willing to open a follow-up PR to add it, I'd be very happy to review :)
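For reference, the repeated-n-gram case mentioned above could be exercised with something like the following sketch. `get_banned_tokens` is a tiny reference helper written for this sketch, not a transformers API; it enumerates the n-grams seen so far and returns the continuations of the current prefix.

```python
def get_banned_tokens(seq, ngram_size):
    # Map each (ngram_size - 1)-token prefix to the set of tokens that followed it.
    continuations = {}
    for i in range(len(seq) - ngram_size + 1):
        prefix = tuple(seq[i : i + ngram_size - 1])
        continuations.setdefault(prefix, set()).add(seq[i + ngram_size - 1])
    # Tokens banned for the next step are the continuations of the current prefix.
    return sorted(continuations.get(tuple(seq[-(ngram_size - 1):]), set()))

# The 2-gram (1, 2) appears twice here; a correct processor still bans exactly {2}.
print(get_banned_tokens([1, 2, 1, 2, 1], 2))  # [2]
```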
amyeroberts
left a comment
Thanks for adding this!
Thinking back on it -- let's not block waiting for the test reworks; this can be done in follow-ups.
Just needs a `make fixup` run to resolve the quality checks and we should be good to merge!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Merging as there is a green light and green CI! 🥳
What does this PR do?
Adding the no-repeat n-gram logits processor to Flax, compatible with jitting.
I also added the test `test_no_repeat_ngram_dist_processor`, adapted from the torch one, and added the `FlaxNoRepeatNGramLogitsProcessor` to `test_processor_list` and `test_processor_list_jitted`. All the tests are passing, as `RUN_SLOW=1 pytest -sv tests/generation/test_flax_logits_process.py` prints.
Note: in order to work properly within beam search, this processor needs the fix proposed in PR #29636 for the bug discussed in #29635.
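Assuming the new class follows the same calling convention as the other Flax logits processors, it would be invoked as `processor(input_ids, scores, cur_len)` and return scores with banned continuations set to `-inf`. The tiny stand-in below just illustrates that effect without importing transformers; `ban_tokens` is a hypothetical helper for this sketch, not part of the library.

```python
import jax.numpy as jnp

def ban_tokens(scores, banned_token_ids):
    # -inf logits receive zero probability after softmax, so the banned
    # tokens can never be sampled at this generation step.
    return scores.at[jnp.array(banned_token_ids)].set(-jnp.inf)

scores = jnp.zeros(3)
print(jnp.isinf(ban_tokens(scores, [1, 2])).tolist())  # [False, True, True]
```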
Let me know if you have any comments or questions regarding this feature.
Who can review?
@gante
@sanchit-gandhi