Conversation
- added tests
pcuenca
left a comment
This is fantastic, @jkrukowski!
Let me try it out and check the implementation against the current one in transformers in case there are some details that could be incorporated, but this looks great already.
As a side comment, we could potentially implement the costly cumsum operation in Core ML as part of the model conversion, or using a Core ML pipeline. But using Accelerate should be more than enough for now!
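For context on why the cumulative sum matters here: top-p (nucleus) filtering keeps the smallest set of most-probable tokens whose cumulative probability reaches `p`, so a running sum over the sorted probabilities sits on the hot path. A minimal plain-Swift sketch (hypothetical names, not the PR's actual `Math` API):

```swift
import Foundation

/// Illustrative top-p (nucleus) filter; not the PR's actual implementation.
/// Returns the indices of the smallest set of tokens whose probability
/// mass reaches `p`, in descending probability order.
func topPIndices(probs: [Float], p: Float) -> [Int] {
    // Sort token indices by probability, descending.
    let sorted = probs.indices.sorted { probs[$0] > probs[$1] }
    var cumulative: Float = 0
    var selected: [Int] = []
    for index in sorted {
        selected.append(index)
        cumulative += probs[index] // the O(n) running sum Accelerate can vectorize
        if cumulative >= p { break }
    }
    return selected
}
```

For example, `topPIndices(probs: [0.5, 0.3, 0.1, 0.1], p: 0.8)` returns `[0, 1]`, since the two most probable tokens already cover 80% of the mass.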
pcuenca
left a comment
I tested it and it works fine! As expected, it's a bit slow but we can try to optimize later. Additionally, top-k and top-p could potentially coexist as pointed out below, but we can also handle that in a new PR unless you want to tackle it now :)
```swift
if config.topK > 0 {
    let topK = Math.topK(arr: logits, k: config.topK)
    nextToken = Math.sample(indexes: topK.indexes, probs: topK.probs)
} else if config.topP < 1.0 {
```
If my understanding of this is correct, top-k can coexist with top-p: https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L805-L808
However, it could make sense to merge this PR now and make them coexist in a future one. What do you think?
I'd say let's merge it now, seems logical to create a separate PR with a common interface to these two
```swift
fatalError("topP not implemented yet")
fatalError("not implemented yet")
```
If we make top-k compatible with top-p, we'd do a single sample call on the selected tokens and remove this fatalError.
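One possible shape for that combined path, sketched below with hypothetical helpers (only `config.topK`/`config.topP` come from the PR; everything else is an assumption): filter with top-k first, then top-p on the survivors, and finish with a single weighted sample, mirroring the order transformers uses.

```swift
import Foundation

/// Hypothetical combined top-k + top-p sampling sketch, not the PR's code.
func sampleToken(logits: [Float], topK: Int, topP: Float) -> Int {
    // Candidates sorted by logit, descending.
    var candidates = logits.indices.sorted { logits[$0] > logits[$1] }
    if topK > 0 {
        candidates = Array(candidates.prefix(topK))
    }
    // Softmax over the surviving candidates only.
    let maxLogit = candidates.map { logits[$0] }.max() ?? 0
    var probs = candidates.map { exp(logits[$0] - maxLogit) }
    let total = probs.reduce(0, +)
    probs = probs.map { $0 / total }
    if topP < 1.0 {
        // Keep the smallest prefix whose cumulative mass reaches topP.
        var cumulative: Float = 0
        var cut = probs.count
        for (i, p) in probs.enumerated() {
            cumulative += p
            if cumulative >= topP { cut = i + 1; break }
        }
        candidates = Array(candidates.prefix(cut))
        probs = Array(probs.prefix(cut))
    }
    // Single weighted draw over whatever survived both filters.
    var r = Float.random(in: 0..<probs.reduce(0, +))
    for (i, p) in probs.enumerated() {
        r -= p
        if r <= 0 { return candidates[i] }
    }
    return candidates.last ?? 0
}
```

With this structure the `fatalError` branch disappears: both filters narrow the candidate set, and sampling happens exactly once at the end.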
In this PR I've compared two different implementations, see https://github.com/jkrukowski/topp -- looks like using Accelerate to compute the cumulative sum gives a speed boost.
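A sketch of what the Accelerate-based prefix sum can look like (this is an illustration, not the benchmarked code from that repo). One quirk to watch for: `vDSP_vrsum` computes a running sum that excludes the first input element, so the first element has to be added back to every output:

```swift
import Accelerate

/// Inclusive prefix sum via Accelerate.
/// vDSP_vrsum yields C[n] = scale * (A[1] + ... + A[n]), dropping A[0],
/// so we add A[0] back to every output element to get an inclusive cumsum.
func cumulativeSum(_ input: [Float]) -> [Float] {
    guard !input.isEmpty else { return [] }
    var output = [Float](repeating: 0, count: input.count)
    var scale: Float = 1
    vDSP_vrsum(input, 1, &scale, &output, 1, vDSP_Length(input.count))
    return vDSP.add(input[0], output)
}
```

Under these assumptions, `cumulativeSum([1, 2, 3, 4])` should yield `[1, 3, 6, 10]`, and the vectorized running sum is where the speedup over a scalar loop comes from.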